import copy
import datetime
import itertools
import operator
import os
import pickle
from collections import namedtuple

import gspread
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import scipy.stats  # makes sp.stats available for pearsonr below
import spikeinterface as si
from oauth2client.service_account import ServiceAccountCredentials
from spyglass.spikesorting import (
SortingviewWorkspace,
SpikeSorting,
SpikeSortingRecording,
)
from spyglass.spikesorting.spikesorting_curation import (
AutomaticCuration,
AutomaticCurationParameters,
AutomaticCurationSelection,
CuratedSpikeSorting,
CuratedSpikeSortingSelection,
Curation,
MetricParameters,
MetricSelection,
QualityMetrics,
WaveformParameters,
Waveforms,
WaveformSelection,
)
def overlap(x, y):
z = list(x) + list(y)
if not all(np.isfinite(z)):
raise Exception(f"All elements must be finite")
if not all(np.asarray(z) >= 0):
raise Exception(f"All elements must be nonnegative")
return (
2 * np.sum(np.min(np.vstack((x, y)), axis=0)) / (np.sum(x) + np.sum(y))
)
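# Hedged illustration (not part of the original gist): `overlap` is a
# histogram-intersection-style score, 1 for identical histograms and 0 for
# disjoint ones.
def _example_overlap_usage():
    assert np.isclose(overlap([1.0, 2.0, 0.0], [1.0, 2.0, 0.0]), 1.0)
    assert np.isclose(overlap([1.0, 0.0], [0.0, 1.0]), 0.0)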
def add_colorbar(
img, fig, ax, cbar_location="left", size="5%", pad_factor=0.05
):
from mpl_toolkits.axes_grid1 import make_axes_locatable
divider = make_axes_locatable(ax)
pad_factor *= fig.get_size_inches()[0] # scale pad based on figure width
cax = divider.append_axes(cbar_location, size=size, pad=pad_factor)
cbar = fig.colorbar(img, ax=[ax], cax=cax)
return cbar
def make_param_name(param_values):
return "_1".join([str(x) for x in param_values])
def format_nwb_file_name(nwb_file_name):
return nwb_file_name.split("_")[0]
def get_google_spreadsheet(
service_account_dir,
service_account_json,
spreadsheet_key,
spreadsheet_tab_name,
):
scope = ["https://spreadsheets.google.com/feeds"]
# Change to directory with service account credentials json
os.chdir(service_account_dir)
# Get service account credentials from json file
service_account_credentials = (
ServiceAccountCredentials.from_json_keyfile_name(
service_account_json, scope
)
)
# Get spreadsheet
client_obj = gspread.authorize(service_account_credentials)
spreadsheet_obj = client_obj.open_by_key(spreadsheet_key)
worksheet = spreadsheet_obj.worksheet(spreadsheet_tab_name)
return worksheet.get_all_values()
def check_one_none(x, list_element_names=""):
num_none = len([x_i for x_i in x if x_i is None])
if num_none != 1:
raise Exception(
f"Need exactly one None in passed arguments {list_element_names}"
f" but got {num_none} Nones"
)
def get_curation_spreadsheet(subject_id, date, tolerate_no_notes=True):
# Get data from google spreadsheet
service_account_dir = ""
service_account_json = ""
spreadsheet_key = ""
spreadsheet_tab_name = f"{subject_id}{date}"
column_names = np.asarray(
[
"sort_group_id",
"unit_id_1",
"unit_id_2",
"merge_type",
"notes",
"label",
"potential action items",
]
)
try:
table = np.asarray(
get_google_spreadsheet(
service_account_dir,
service_account_json,
spreadsheet_key,
spreadsheet_tab_name,
)
)
    except Exception:
failure_message = f"Could not get google spreadsheet with curation notes for {subject_id}{date}"
if tolerate_no_notes:
print(failure_message)
return pd.DataFrame(columns=column_names)
else:
raise Exception(failure_message)
# Get labels as df
# ...Get index of row where labels start: the one after column names
row_idx = (
unpack_single_element(
np.where(np.prod(table == column_names, axis=1))[0]
)
+ 1
)
# ...Convert to dataframe. Here, convert datatype from string as appropriate
def _convert_curation_spreadsheet_dtype(row, column_name):
int_column_names = ["sort_group_id", "unit_id_1", "unit_id_2"]
bool_column_names = ["label"]
bool_map = {"yes": True, "no": False, "unsure": None}
if column_name in int_column_names:
return row.astype(int)
elif column_name in bool_column_names:
return [
bool_map[x.strip()] for x in row
] # strip whitespace and convert to bool variable
return row
return pd.DataFrame.from_dict(
{
column_name: _convert_curation_spreadsheet_dtype(row, column_name)
for column_name, row in zip(column_names, table[row_idx:].T)
}
)
def get_cluster_data_file_name(
nwb_file_name,
sort_interval_name,
sorter,
preproc_params_name,
curation_id,
sort_group_id=None,
target_region=None,
):
# Check inputs
check_one_none([sort_group_id, target_region])
electrode_text = target_region
if sort_group_id is not None:
electrode_text = sort_group_id
return make_param_name(
[
format_nwb_file_name(nwb_file_name),
sort_interval_name,
sorter,
preproc_params_name,
curation_id,
electrode_text,
]
)
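# Hedged example (the nwb file name below is hypothetical, not from the original
# gist): cluster data file names are the parameter values joined by the "_1"
# separator used in make_param_name, with the nwb file name truncated at "_".
def _example_cluster_data_file_name():
    name = get_cluster_data_file_name(
        nwb_file_name="wilbur20210101_.nwb",
        sort_interval_name="r2_r3",
        sorter="mountainsort4",
        preproc_params_name="franklab_tetrode_hippocampus",
        curation_id=1,
        sort_group_id=5,
    )
    assert name == (
        "wilbur20210101_1r2_r3_1mountainsort4_1franklab_tetrode_hippocampus_11_15"
    )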
def load_curation_data(
save_dir,
nwb_file_name,
sort_interval_name,
sorter="mountainsort4",
preproc_params_name="franklab_tetrode_hippocampus",
sort_group_ids=None,
target_region=None,
curation_id=1,
verbose=True,
overwrite_quantities=True,
):
# Check inputs
check_one_none([sort_group_ids, target_region])
if verbose:
print(f"Loading curation data for {nwb_file_name}...")
cd_make_if_nonexistent(save_dir)
if target_region is not None:
file_name_save = get_cluster_data_file_name(
nwb_file_name,
sort_interval_name,
sorter,
preproc_params_name,
curation_id,
target_region=target_region,
)
return pickle.load(open(file_name_save, "rb"))
sort_groups_data = dict()
for sort_group_id in sort_group_ids:
file_name_save = get_cluster_data_file_name(
nwb_file_name,
sort_interval_name,
sorter,
preproc_params_name,
curation_id,
sort_group_id=sort_group_id,
)
        # Continue if file doesn't exist
if not os.path.exists(file_name_save):
continue
# Store sort group data
sort_groups_data[sort_group_id] = pickle.load(
open(file_name_save, "rb")
)
# TODO: overwrite files with version with this calculation, and then delete this from here
default_params = get_correlogram_default_params()
correlogram_max_dt, correlogram_min_dt = (
default_params["max_dt"],
default_params["min_dt"],
)
for data in sort_groups_data.values():
data["correlogram_isi_violation_ratios"] = (
get_correlogram_isi_violation_ratios(
data, max_dt=correlogram_max_dt, min_dt=correlogram_min_dt
)
)
# !!! TEMPORARY UNTIL MOVE THIS TO make_curation_data
if overwrite_quantities:
# Convert sort group ID from string to int
sort_groups_data = {
int(sort_group_id): sort_group_data
for sort_group_id, sort_group_data in sort_groups_data.items()
}
# Add unit names to sort group data
for sort_group_id, sort_group_data in sort_groups_data.items():
if "unit_ids" not in sort_group_data:
sort_group_data["unit_ids"] = list(
sort_group_data["spike_times"].keys()
)
# Get correlogram quantities
print("Loading correlogram quantities...")
for sort_group_id, sort_group_data in sort_groups_data.items():
sort_group_data["correlogram_asymmetries"] = (
get_correlogram_asymmetries(sort_group_data["correlograms"])
)
sort_group_data["correlogram_asymmetry_directions"] = (
get_correlogram_asymmetry_directions(
sort_group_data["correlograms"]
)
)
sort_group_data["correlogram_counts"] = get_correlogram_counts(
sort_group_data["correlograms"]
)
# Overwrite amplitude overlaps
print("Loading amplitude overlaps...")
for sort_group_id, sort_group_data in sort_groups_data.items():
sort_group_data["amplitude_overlaps"] = get_amplitude_overlaps(
sort_group_data
)
# Get amplitude size comparison
for sort_group_id, sort_group_data in sort_groups_data.items():
sort_group_data["amplitude_size_comparisons"] = (
get_amplitude_size_comparisons(sort_group_data)
)
# Get burst pair amplitude correlogram asymmetry metric
for sort_group_id, sort_group_data in sort_groups_data.items():
sort_group_data["burst_pair_amplitude_timing_bools"] = (
get_burst_pair_amplitude_timing_bools(sort_group_data)
)
# ISI violation percent
print("Loading ISI violation percent...")
for sort_group_id, sort_group_data in sort_groups_data.items():
sort_group_data["unit_pair_percent_isi_violations"] = (
get_unit_pair_percent_isi_violations(sort_group_data)
)
# Valid lower amplitude fractions
print("Loading valid lower amplitude fractions...")
for sort_group_id, sort_group_data in sort_groups_data.items():
print(f"on sort group {sort_group_id}...")
sort_group_data["valid_lower_amplitude_fractions"] = (
get_valid_lower_amplitude_fractions(sort_group_data)
)
sort_group_data["unit_merge_valid_lower_amplitude_fractions"] = (
get_unit_merge_valid_lower_amplitude_fractions(sort_group_data)
)
# !!!!!!!!
return {
"nwb_file_name": nwb_file_name,
"sort_interval_name": sort_interval_name,
"sorter": sorter,
"preproc_params_name": preproc_params_name,
"n_sort_groups": len(sort_group_ids),
"sort_groups": sort_groups_data,
}
def make_curation_data(
save_dir,
nwb_file_name,
sort_interval_name,
sorter="mountainsort4",
preproc_params_name="franklab_tetrode_hippocampus",
sort_group_ids=None,
curation_id=1,
get_workspace_url=True,
ignore_invalid_sort_group_ids=False,
overwrite_existing=False,
verbose=True,
):
# Check that key specific enough (each sort group represented no more than once)
key = {
"nwb_file_name": nwb_file_name,
"sort_interval_name": sort_interval_name,
"sorter": sorter,
"preproc_params_name": preproc_params_name,
"curation_id": curation_id,
}
valid_sort_group_ids = [
x for x in (SpikeSorting & key).fetch("sort_group_id")
]
check_all_unique(valid_sort_group_ids)
# Define sort group ids if not passed
if sort_group_ids is None:
sort_group_ids = valid_sort_group_ids
# Check that passed sort group ids are valid
if not ignore_invalid_sort_group_ids and not set(sort_group_ids).issubset(
set(valid_sort_group_ids)
):
raise ValueError(f"List of sort groups includes invalid sort group IDs")
# Get correlogram default params
default_params = get_correlogram_default_params()
correlogram_max_dt, correlogram_min_dt = (
default_params["max_dt"],
default_params["min_dt"],
)
# Loop through sort groups and make cluster data if does not exist or want to overwrite
for sort_group_id in sort_group_ids:
# Continue if sort group invalid and want to tolerate this
if (
ignore_invalid_sort_group_ids
and sort_group_id not in valid_sort_group_ids
):
continue
        # Continue if file already exists and we don't want to overwrite
file_name_save = get_cluster_data_file_name(
nwb_file_name,
sort_interval_name,
sorter,
preproc_params_name,
curation_id,
sort_group_id,
)
if (
os.path.exists(os.path.join(save_dir, file_name_save))
and not overwrite_existing
):
print(
f"Cluster data exists for {nwb_file_name}, sort group {sort_group_id}; continuing"
)
continue
# Otherwise, make cluster data
if verbose:
print(
f"Making cluster data for {nwb_file_name}, sort group {sort_group_id}"
)
# Get key from CuratedSpikeSorting since will need all fields (some of which were not defined
# by user, e.g. team_name) to populate other tables
sort_group_key = (
CuratedSpikeSorting & {**key, **{"sort_group_id": sort_group_id}}
).fetch1("KEY")
data = dict() # for cluster data
# Make key for getting whitened waveforms. Note that here we use a curation id of ZERO, regardless
# of what curation_id was passed
waveforms_key = copy.deepcopy(sort_group_key)
# waveforms_key.update({'curation_id': 0,
# 'waveform_params_name': 'RSN_whitened_float'})
waveforms_key.update(
{"curation_id": 0, "waveform_params_name": "5k_whitened_float_1"}
)
# Populate waveforms tables if no entry
if not (Waveforms & waveforms_key):
if verbose:
print(f"Populating Waveforms table with key {waveforms_key}...")
WaveformSelection.insert1(waveforms_key, skip_duplicates=True)
Waveforms.populate([(WaveformSelection & waveforms_key).proj()])
# Get workspace URL if indicated
if get_workspace_url:
data["workspace_url"] = SortingviewWorkspace().url(sort_group_key)
# Get timestamps
if verbose:
print(f"Getting timestamps...")
recording_path = (SpikeSortingRecording & sort_group_key).fetch1(
"recording_path"
)
recording = si.load_extractor(recording_path)
timestamps_raw = SpikeSortingRecording._get_recording_timestamps(
recording
)
# Get total recording duration in seconds
data["recording_duration"] = recording.get_total_duration()
# Get spikes data
if verbose:
print(f"Getting spikes data...")
        # ...First get valid unit IDs, for the passed curation_id and metric restrictions.
        # If there are no unit IDs for the given sort group, continue
css_entry = unpack_single_element(
(
CuratedSpikeSorting
& {**sort_group_key, **{"curation_id": curation_id}}
).fetch_nwb()
)
if "units" not in css_entry:
continue
units_df = css_entry["units"]
valid_unit_ids = units_df.index
# ...Get unit metrics
metric_names = [
"snr",
"isi_violation",
"nn_isolation",
"nn_noise_overlap",
] # desired metrics
for metric_name in metric_names:
# Continue if metric name not in curated spike sorting entry
if metric_name not in units_df:
continue
data[metric_name] = units_df[metric_name].to_dict()
# ...Get waveform extractor, which will be used to get other quantities
we = (Waveforms & waveforms_key).load_waveforms(
waveforms_key
) # waveform extractor
data["sampling_frequency"] = we.sorting.get_sampling_frequency()
data["unit_ids"] = valid_unit_ids
data["n_clusters"] = len(valid_unit_ids)
data["n_channels"] = len(we.recording.get_channel_ids())
data["waveform_window"] = np.arange(-we.nbefore, we.nafter)
# IMPORTANT NOTE: WAVEFORMS AND SPIKE TIMES ARE SUBSAMPLED (SEEMS MAX IS AT 20000). Happens in line below.
waveform_data = {
unit_id: we.get_waveforms(unit_id, with_index=True)
for unit_id in valid_unit_ids
}
spike_samples = {
unit_id: we.sorting.get_unit_spike_train(unit_id=unit_id)
for unit_id in valid_unit_ids
}
# TODO: understand the line below
data["waveforms"] = {
unit_id: np.swapaxes(wv[0], 0, 2)
for unit_id, wv in waveform_data.items()
}
data["waveform_indices"] = {
unit_id: np.array(list(zip(*wv[1]))[0]).astype(int)
for unit_id, wv in waveform_data.items()
}
# ...Get spike times
data["spike_times"] = {
unit_id: timestamps_raw[samples[data["waveform_indices"][unit_id]]]
for unit_id, samples in spike_samples.items()
}
# ...Get average waveforms
data["average_waveforms"] = get_average_waveforms(data["waveforms"])
# ...Get peak channels
data["peak_channels"] = get_peak_channels(data["average_waveforms"])
# ...Get waveform amplitudes
data["amplitudes"] = get_waveform_amplitudes(data["waveforms"])
# ...Get amplitude size comparison
data["amplitude_size_comparisons"] = get_amplitude_size_comparisons(
data
)
# Get cosine similarity
if verbose:
print(f"Getting cosine similarities...")
data["cosine_similarities"] = get_cosine_similarities(
data["average_waveforms"]
)
# Get correlogram quantities
if verbose:
print(f"Getting correlograms...")
data["correlograms"] = get_correlograms(
data["spike_times"],
max_dt=correlogram_max_dt,
min_dt=correlogram_min_dt,
)
data["correlogram_isi_violation_ratios"] = (
get_correlogram_isi_violation_ratios(
data, max_dt=correlogram_max_dt, min_dt=correlogram_min_dt
)
)
data["correlogram_asymmetries"] = get_correlogram_asymmetries(
data["correlograms"]
)
data["correlogram_asymmetry_directions"] = (
get_correlogram_asymmetry_directions(data["correlograms"])
)
data["correlogram_counts"] = get_correlogram_counts(
data["correlograms"]
)
data["correlogram_min_dt"] = correlogram_max_dt
data["correlogram_max_dt"] = correlogram_max_dt
# Get amplitude overlap
if verbose:
print(f"Getting amplitude overlaps...")
data["amplitude_overlaps"] = get_amplitude_overlaps(data)
# Get burst pair amplitude correlogram asymmetry metric
data["burst_pair_amplitude_timing_bools"] = (
get_burst_pair_amplitude_timing_bools(data)
)
# Get ISI violation percent for merged unit pairs
data["unit_pair_percent_isi_violations"] = (
get_unit_pair_percent_isi_violations(data)
)
# Get amplitude decrement metrics
if verbose:
print(f"Getting amplitude decrement quantities...")
for max_dt in [0.015, 0.4]:
data[f"amplitude_decrements_{max_dt}"] = (
get_unit_amplitude_decrements(data, max_dt)
)
data[f"unit_merge_amplitude_decrements_{max_dt}"] = (
get_unit_merge_amplitude_decrements(data, max_dt)
)
data[f"amplitude_decrement_changes_{max_dt}"] = (
get_amplitude_decrement_changes(data, max_dt)
)
# Valid lower amplitude fractions
if verbose:
print(f"Getting valid lower amplitude fractions...")
data["valid_lower_amplitude_fractions"] = (
get_valid_lower_amplitude_fractions(data)
)
data["unit_merge_valid_lower_amplitude_fractions"] = (
get_unit_merge_valid_lower_amplitude_fractions(data)
)
# Save data
cd_make_if_nonexistent(save_dir)
if verbose:
print(f"Saving {file_name_save} in {save_dir}...")
pickle.dump(data, open(file_name_save, "wb")) # save data
# need to update this for each user
def get_curation_data_save_dir(subject_id):
return f"/cumulus/mcoulter/curation_data/{subject_id}"
def make_curation_data_wrapper(
subject_ids,
dates,
sort_interval_name="raw data valid times no premaze no home",
preproc_params_name="franklab_tetrode_hippocampus",
sorter="mountainsort4",
sort_group_ids=None,
get_workspace_url=False,
curation_id=1,
ignore_invalid_sort_group_ids=False,
overwrite_existing=False,
verbose=True,
):
# Make curation data
for subject_id, date in zip(subject_ids, dates):
# Get nwb file name
nwb_file_name = nwbf_name_from_subject_id_date(subject_id, date)
# Define directory to save data in
save_dir = get_curation_data_save_dir(subject_id)
# Make curation data
make_curation_data(
save_dir=save_dir,
nwb_file_name=nwb_file_name,
sort_interval_name=sort_interval_name,
sorter=sorter,
preproc_params_name=preproc_params_name,
sort_group_ids=sort_group_ids,
curation_id=curation_id,
get_workspace_url=get_workspace_url,
ignore_invalid_sort_group_ids=ignore_invalid_sort_group_ids,
overwrite_existing=overwrite_existing,
verbose=verbose,
)
def _compute_cluster_data(func_name, data_in):
data_out = {cluster: None for cluster in data_in.keys()}
for cluster, data in data_in.items():
data_out[cluster] = func_name(data)
return data_out
def _compute_pairwise_cluster_data(
func_name, data_in, nested_dict=False, kwargs=None
):
# Get inputs if not passed
if kwargs is None:
kwargs = {}
# Initialize output dictionary
data_out = {
cluster_1: {cluster_2: None for cluster_2 in data_in.keys()}
for cluster_1 in data_in.keys()
}
if nested_dict:
for cluster_1 in data_in.keys():
for cluster_2 in data_in[cluster_1].keys():
data_out[cluster_1][cluster_2] = func_name(
data_in[cluster_1][cluster_2], **kwargs
)
else:
for cluster_1, data_1 in data_in.items():
for cluster_2, data_2 in data_in.items():
data_out[cluster_1][cluster_2] = func_name(
data_1, data_2, **kwargs
)
return data_out
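# Hedged illustration (not from the original gist): _compute_pairwise_cluster_data
# applies a function to every ordered pair of units and returns a nested dict keyed
# by unit, e.g. with operator.add on per-unit scalars:
def _example_pairwise_cluster_data():
    result = _compute_pairwise_cluster_data(operator.add, {"u1": 1, "u2": 2})
    assert result == {"u1": {"u1": 2, "u2": 3}, "u2": {"u1": 3, "u2": 4}}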
def _compute_average_waveform(wv):
wv_avg = np.mean(wv, axis=2)
return wv_avg
def _compute_peak_channel(wv_avg):
idx = np.argmax(_compute_waveform_amplitude(wv_avg))
return idx
def _compute_waveform_amplitude(wv):
amp = np.max(wv, axis=1) - np.min(wv, axis=1)
return amp
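# Hedged illustration (not from the original gist): peak-to-trough amplitude is
# computed per channel (axis=1 is the sample axis), and the peak channel is the
# channel with the largest amplitude of the average waveform.
def _example_waveform_amplitude():
    wv_avg = np.array([[0.0, 3.0, -1.0], [0.0, 1.0, 0.0]])  # (n_channels, n_points)
    assert np.allclose(_compute_waveform_amplitude(wv_avg), [4.0, 1.0])
    assert _compute_peak_channel(wv_avg) == 0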
def _compute_amplitude_overlaps(data, unit_1, unit_2, bin_width=0.1):
# Find peak amplitude channel for each unit
unit_ids = [unit_1, unit_2]
peak_channels = np.unique(
[data["peak_channels"][unit_id] for unit_id in unit_ids]
)
# For each unique peak amplitude channel, find overlap of normalized histograms of
# amplitude distribution across units
overlaps = [] # overlap across peak channels
for peak_channel in peak_channels: # peak channels
# Get amplitudes for units
unit_amplitudes = np.asarray(
[
unpack_single_element(
_compute_waveform_amplitude(
data["waveforms"][unit_id][[peak_channel], :, :]
)
)
for unit_id in unit_ids
]
)
# Use minimum and maximum amplitude seen across units to form histogram bins
concatenated_unit_amplitudes = np.concatenate(unit_amplitudes)
bin_edges = np.arange(
np.min(concatenated_unit_amplitudes),
np.max(concatenated_unit_amplitudes) + bin_width,
bin_width,
)
# Find overlap between normalized histograms
overlaps.append(
overlap(
*[
np.histogram(amplitudes, bin_edges, density=True)[0]
for amplitudes in unit_amplitudes
]
)
)
# Take average of overlaps across unit peak amplitude channels
return np.mean(overlaps)
def _compare_amplitude_size(data, unit_1, unit_2):
unit_1_mean = np.mean(
data["amplitudes"][unit_1][data["peak_channels"][unit_1]]
)
unit_2_mean = np.mean(
data["amplitudes"][unit_2][data["peak_channels"][unit_1]]
)
if unit_1_mean < unit_2_mean:
return -1
if unit_1_mean == unit_2_mean:
return 0
if unit_1_mean > unit_2_mean:
return 1
def _compute_cosine_similarity(wv_avg_1, wv_avg_2):
wv_avg_1, wv_avg_2 = (np.ravel(wv_avg) for wv_avg in (wv_avg_1, wv_avg_2))
wv_avg_nrm_1, wv_avg_nrm_2 = (
wv_avg / np.linalg.norm(wv_avg, axis=0)
for wv_avg in (wv_avg_1, wv_avg_2)
)
sim = np.dot(wv_avg_nrm_1, wv_avg_nrm_2)
return sim
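# Hedged illustration (not from the original gist): cosine similarity is scale
# invariant, so an average waveform compared with a scaled copy of itself gives 1.
def _example_cosine_similarity():
    wv_avg = np.array([[1.0, 0.0, -2.0], [0.5, 2.0, 0.0]])
    assert np.isclose(_compute_cosine_similarity(wv_avg, 3.0 * wv_avg), 1.0)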
def get_correlogram_default_params():
return {"max_dt": 0.5, "min_dt": 0}
def _compute_correlogram(spk_times_1, spk_times_2, max_dt=None, min_dt=None):
# Get inputs if not passed
if max_dt is None:
max_dt = get_correlogram_default_params()["max_dt"]
if min_dt is None:
min_dt = get_correlogram_default_params()["min_dt"]
time_diff = (
np.tile(spk_times_1, (spk_times_2.size, 1)) - spk_times_2[:, np.newaxis]
)
ind = np.logical_and(
np.abs(time_diff) > min_dt, np.abs(time_diff) <= max_dt
)
time_diff = np.sort(time_diff[ind])
return time_diff
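# Hedged worked example (not from the original gist): the correlogram is the sorted
# array of pairwise spike-time differences (first unit minus second unit) whose
# magnitude lies in (min_dt, max_dt].
def _example_correlogram():
    diffs = _compute_correlogram(
        np.array([0.0, 0.1]), np.array([0.05]), max_dt=0.5, min_dt=0.0
    )
    assert np.allclose(diffs, [-0.05, 0.05])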
def _compute_correlogram_count(
time_diff, min_dt=-200 / 1000, max_dt=200 / 1000
):
return np.sum(np.logical_and(time_diff > min_dt, time_diff < max_dt))
def _compute_correlogram_asymmetry_direction(
time_diff, min_dt=-200 / 1000, max_dt=200 / 1000
):
neg_count = np.sum(np.logical_and(time_diff > min_dt, time_diff < 0))
pos_count = np.sum(np.logical_and(time_diff > 0, time_diff < max_dt))
if neg_count > pos_count:
return -1
if neg_count == pos_count:
return 0
if pos_count > neg_count:
return 1
def _compute_correlogram_asymmetry(
time_diff, min_dt=-200 / 1000, max_dt=200 / 1000
):
zero_count = np.sum(time_diff == 0)
neg_count = np.sum(np.logical_and(time_diff > min_dt, time_diff < 0))
pos_count = np.sum(np.logical_and(time_diff > 0, time_diff < max_dt))
asym = (np.max([neg_count, pos_count]) + zero_count / 2) / (
zero_count + neg_count + pos_count
)
return asym
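# Hedged worked example (not from the original gist): asymmetry is the larger of
# the negative-lag and positive-lag counts (plus half of any zero-lag counts)
# divided by the total count, so it ranges from 0.5 (symmetric) to 1 (one-sided).
def _example_correlogram_asymmetry():
    time_diff = np.array([-0.01, 0.005, 0.01, 0.02])
    assert np.isclose(_compute_correlogram_asymmetry(time_diff), 0.75)
    assert _compute_correlogram_asymmetry_direction(time_diff) == 1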
def percent_isi_violations(spike_train, isi_threshold):
isis = np.diff(spike_train)
num_isi_violations = np.sum(isis < isi_threshold)
return 100 * num_isi_violations / len(isis)
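# Hedged worked example (not from the original gist): one of three inter-spike
# intervals below a 1.5 ms threshold gives ~33.3% violations.
def _example_percent_isi_violations():
    spike_train = np.array([0.0, 0.001, 0.101, 0.201])
    assert np.isclose(
        percent_isi_violations(spike_train, isi_threshold=0.0015), 100 / 3
    )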
def _compute_correlogram_isi_violation_ratio(
correlogram, correlogram_window_width, isi_threshold=None
):
# Get inputs if not passed
if isi_threshold is None:
isi_threshold = 0.0015
# Find violations in correlogram
invalid_bool = abs(correlogram) < isi_threshold
# Compute fraction of correlogram that is violations
correlogram_isi_violation = np.sum(invalid_bool) / len(invalid_bool)
# Calculate expected violation ratio if correlogram uniform
uniform_violation = (isi_threshold * 2) * correlogram_window_width
# Return ratio of actual violation ratio to ratio expected if uniform correlogram
return correlogram_isi_violation / uniform_violation
def _burst_pair_amplitude_timing_bool(data, unit_1, unit_2):
return (
data["amplitude_size_comparisons"][unit_1][unit_2]
* data["correlogram_asymmetry_directions"][unit_1][unit_2]
< 0
)
# Amplitude decrement
def _channel_amplitudes(data, unit_id, channel):
return unpack_single_element(
_compute_waveform_amplitude(data["waveforms"][unit_id][[channel], :, :])
)
def _time_diff(spike_times, min_dt, max_dt):
time_diff = (
np.tile(spike_times, (spike_times.size, 1)) - spike_times[:, np.newaxis]
)
ind = np.logical_and(
np.abs(time_diff) > min_dt, np.abs(time_diff) <= max_dt
)
return time_diff, ind
# Slow step: this takes about 10 s, and it has to run 4x for each cluster pair.
# The slow steps are 0->1 and 1->2. Note: this is much slower for tetrodes with many
# spikes; if you reduce spikes to 5k from 20k, the speedup is 10-16x.
def _compute_amplitude_decrement(spike_times, amplitudes, max_dt=None):
# print('_compute_unit_decrement',datetime.datetime.now())
# Get inputs if not passed
if max_dt is None:
max_dt = 0.015
# print('_compute_unit_decrement 0',datetime.datetime.now(),len(spike_times))
time_diff, ind = _time_diff(spike_times, 0, max_dt)
# print('_compute_unit_decrement 1',datetime.datetime.now())
amplitude_diff = (
np.tile(amplitudes, (amplitudes.size, 1)) - amplitudes[:, np.newaxis]
)
# print('_compute_unit_decrement 2',datetime.datetime.now())
valid_time_diff = time_diff[ind]
# print('_compute_unit_decrement 3',datetime.datetime.now())
valid_amplitude_diff = amplitude_diff[ind]
# print('_compute_unit_decrement 4',datetime.datetime.now())
# Return nan if fewer than two valid samples, since in this case cannot calculate correlation
if len(valid_time_diff) < 2:
return np.nan
return sp.stats.pearsonr(valid_time_diff, valid_amplitude_diff)[0]
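# Hedged worked example (not from the original gist): amplitude decrement is the
# Pearson correlation between pairwise spike-time differences and amplitude
# differences within max_dt, so a burst whose amplitude falls off linearly in time
# gives a correlation of -1.
def _example_amplitude_decrement():
    spike_times = np.array([0.0, 0.005, 0.010])
    amplitudes = np.array([30.0, 20.0, 10.0])
    decrement = _compute_amplitude_decrement(spike_times, amplitudes, max_dt=0.015)
    assert np.isclose(decrement, -1.0)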
def _compute_unit_amplitude_decrement(data, unit_id, max_dt=None):
return _compute_amplitude_decrement(
spike_times=data["spike_times"][unit_id],
amplitudes=_channel_amplitudes(
data, unit_id, data["peak_channels"][unit_id]
),
max_dt=max_dt,
)
# this is the slow step: runs computation 4x per cluster pair
# maybe this step could be parallelized because the same computation is run 4 times
def _compute_unit_merge_amplitude_decrement(data, unit_1, unit_2, max_dt=None):
unit_ids = [unit_1, unit_2]
# print('_compute_unit_merge_amplitude_decrement',datetime.datetime.now(),'units',unit_1,unit_2)
return np.mean(
[
_compute_amplitude_decrement(
spike_times=np.concatenate(
[data["spike_times"][unit_id] for unit_id in unit_ids]
),
amplitudes=np.concatenate(
[
_channel_amplitudes(
data,
unit_id,
data["peak_channels"][peak_channel_unit_id],
)
for unit_id in unit_ids
]
),
max_dt=max_dt,
)
for peak_channel_unit_id in unit_ids
]
)
def _compute_amplitude_decrement_change(data, unit_1, unit_2, max_dt):
# print('_compute_amplitude_decrement_change',datetime.datetime.now())
# Get average amplitude decrement across the two units
unit_amplitude_decrement = np.mean(
[
data[f"amplitude_decrements_{max_dt}"][unit_id]
for unit_id in [unit_1, unit_2]
]
)
# Get amplitude decrement metric for merged case
unit_merge_amplitude_decrement = data[
f"unit_merge_amplitude_decrements_{max_dt}"
][unit_1][unit_2]
return unit_merge_amplitude_decrement - unit_amplitude_decrement
# Valid lower amplitude fraction
def _compute_valid_lower_amplitude_fraction(
spike_times, amplitudes, percentile=None, valid_window=None
):
# Get inputs if not passed
if percentile is None:
percentile = 5
if valid_window is None:
valid_window = 0.5
# Get data value at passed percentile
threshold = np.percentile(amplitudes, percentile)
# Threshold data
below_threshold_spike_times = spike_times[amplitudes < threshold]
above_threshold_spike_times = spike_times[amplitudes >= threshold]
# Find fraction of lower amplitude spikes that have an upper amplitude spike within some amount of time
below_threshold_tile = np.tile(
below_threshold_spike_times, (len(above_threshold_spike_times), 1)
)
above_threshold_tile = np.tile(
above_threshold_spike_times, (len(below_threshold_spike_times), 1)
).T
spike_time_differences = above_threshold_tile - below_threshold_tile
valid_bool = np.sum(abs(spike_time_differences) < valid_window, axis=0) > 0
# Return quantities
fraction_lower_amplitude_valid = np.sum(valid_bool) / len(valid_bool)
valid_lower_amplitude_spike_times = below_threshold_spike_times[valid_bool]
valid_lower_amplitudes = amplitudes[amplitudes < threshold][valid_bool]
return (
fraction_lower_amplitude_valid,
valid_lower_amplitude_spike_times,
valid_lower_amplitudes,
)
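# Hedged worked example (not from the original gist): with a 50th-percentile
# threshold, the single low-amplitude spike at t=0 has a large-amplitude spike
# within the 0.5 s window, so the valid fraction is 1.0.
def _example_valid_lower_amplitude_fraction():
    fraction = _compute_valid_lower_amplitude_fraction(
        spike_times=np.array([0.0, 0.1, 0.2, 5.0]),
        amplitudes=np.array([1.0, 10.0, 10.0, 10.0]),
        percentile=50,
        valid_window=0.5,
    )[0]
    assert np.isclose(fraction, 1.0)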
def _compute_unit_merge_valid_lower_amplitude_fraction(
data, unit_1, unit_2, percentile=None, valid_window=None
):
unit_ids = [unit_1, unit_2]
return np.mean(
[
_compute_valid_lower_amplitude_fraction(
spike_times=np.concatenate(
[data["spike_times"][unit_id] for unit_id in unit_ids]
),
amplitudes=np.concatenate(
[
_channel_amplitudes(
data,
unit_id,
data["peak_channels"][peak_channel_unit_id],
)
for unit_id in unit_ids
]
),
percentile=percentile,
valid_window=valid_window,
)[0]
for peak_channel_unit_id in unit_ids
]
)
# Get quantities
def get_average_waveforms(waveforms):
return _compute_cluster_data(_compute_average_waveform, waveforms)
def get_peak_channels(average_waveforms):
return _compute_cluster_data(_compute_peak_channel, average_waveforms)
def get_waveform_amplitudes(waveforms):
return _compute_cluster_data(_compute_waveform_amplitude, waveforms)
def get_amplitude_overlaps(data):
unit_ids = data["unit_ids"]
return {
unit_1: {
unit_2: _compute_amplitude_overlaps(data, unit_1, unit_2)
for unit_2 in unit_ids
}
for unit_1 in unit_ids
}
def get_amplitude_size_comparisons(data):
unit_ids = data["unit_ids"]
return {
unit_1: {
unit_2: _compare_amplitude_size(data, unit_1, unit_2)
for unit_2 in unit_ids
}
for unit_1 in unit_ids
}
def get_cosine_similarities(average_waveforms):
return _compute_pairwise_cluster_data(
_compute_cosine_similarity, average_waveforms
)
def get_correlograms(spike_times, max_dt=None, min_dt=None):
return _compute_pairwise_cluster_data(
_compute_correlogram,
spike_times,
kwargs={"max_dt": max_dt, "min_dt": min_dt},
)
def get_correlogram_counts(spike_time_differences, kwargs=None):
return _compute_pairwise_cluster_data(
_compute_correlogram_count,
spike_time_differences,
nested_dict=True,
kwargs=kwargs,
)
def get_correlogram_asymmetries(spike_time_differences, kwargs=None):
return _compute_pairwise_cluster_data(
_compute_correlogram_asymmetry,
spike_time_differences,
nested_dict=True,
kwargs=kwargs,
)
def get_correlogram_asymmetry_directions(spike_time_differences, kwargs=None):
return _compute_pairwise_cluster_data(
_compute_correlogram_asymmetry_direction,
spike_time_differences,
nested_dict=True,
kwargs=kwargs,
)
def get_burst_pair_amplitude_timing_bools(data):
unit_ids = data["unit_ids"]
return {
unit_1: {
unit_2: _burst_pair_amplitude_timing_bool(data, unit_1, unit_2)
for unit_2 in unit_ids
}
for unit_1 in unit_ids
}
def merge_spike_times(data, unit_1, unit_2):
return np.sort(
np.concatenate(
(data["spike_times"][unit_1], data["spike_times"][unit_2])
)
)
def get_unit_pair_percent_isi_violations(data, isi_threshold=0.0015):
unit_ids = data["unit_ids"]
return {
unit_1: {
unit_2: percent_isi_violations(
merge_spike_times(data, unit_1, unit_2), isi_threshold
)
for unit_2 in unit_ids
}
for unit_1 in unit_ids
}
def get_correlogram_isi_violation_ratios(
data, max_dt, min_dt, isi_threshold=None
):
unit_ids = data["unit_ids"]
correlogram_window_width = max_dt * 2 - min_dt * 2
return {
unit_1: {
unit_2: _compute_correlogram_isi_violation_ratio(
correlogram=data["correlograms"][unit_1][unit_2],
correlogram_window_width=correlogram_window_width,
isi_threshold=isi_threshold,
)
for unit_2 in unit_ids
}
for unit_1 in unit_ids
}
def get_unit_amplitude_decrements(data, max_dt=None):
return {
unit_id: _compute_unit_amplitude_decrement(data, unit_id, max_dt)
for unit_id in data["unit_ids"]
}
def get_unit_merge_amplitude_decrements(data, max_dt=None):
unit_ids = data["unit_ids"]
# print('cluster pair',unit_1,unit_2)
return {
unit_1: {
unit_2: _compute_unit_merge_amplitude_decrement(
data, unit_1, unit_2, max_dt
)
for unit_2 in unit_ids
}
for unit_1 in unit_ids
}
def get_amplitude_decrement_changes(data, max_dt):
unit_ids = data["unit_ids"]
return {
unit_1: {
unit_2: _compute_amplitude_decrement_change(
data, unit_1, unit_2, max_dt
)
for unit_2 in unit_ids
}
for unit_1 in unit_ids
}
def get_valid_lower_amplitude_fractions(
data, percentile=None, valid_window=None
):
return {
unit_id: _compute_valid_lower_amplitude_fraction(
spike_times=data["spike_times"][unit_id],
amplitudes=data["amplitudes"][unit_id][
data["peak_channels"][unit_id], :
],
percentile=percentile,
valid_window=valid_window,
)[0]
for unit_id in data["unit_ids"]
}
def get_unit_merge_valid_lower_amplitude_fractions(
data, percentile=None, valid_window=None
):
unit_ids = data["unit_ids"]
return {
unit_1: {
unit_2: _compute_unit_merge_valid_lower_amplitude_fraction(
data,
unit_1,
unit_2,
percentile=percentile,
valid_window=valid_window,
)
for unit_2 in unit_ids
}
for unit_1 in unit_ids
}
# Analysis
def get_merge_candidates(cluster_data, threshold_sets, sort_group_ids=None):
# Get inputs if not passed
if sort_group_ids is None:
sort_group_ids = list(cluster_data["sort_groups"].keys())
# Loop through sort group IDs and apply thresholds to get merge candidates
merge_candidates_map = {
threshold_set_name: [] for threshold_set_name in threshold_sets.keys()
}
for sort_group_id in sort_group_ids:
data = cluster_data["sort_groups"][sort_group_id]
# Apply threshold sets
valid_bool_map = get_above_threshold_matrix_indices(
cluster_data, sort_group_id, threshold_sets
)
for (
threshold_set_name,
valid_bool,
) in valid_bool_map.items(): # threshold sets
# Find indices in array corresponding to merge candidates
merge_candidate_idxs = list(zip(*np.where(valid_bool)))
# Convert merge candidate indices in array to unit IDs
merge_candidates_map[threshold_set_name] += [
tuple(
[sort_group_id]
+ list(np.asarray(data["unit_ids"])[np.asarray(idxs)])
)
for idxs in merge_candidate_idxs
]
return merge_candidates_map
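# Hedged sketch (names and values below are assumptions, not from the original
# gist): the threshold_sets argument is expected to map a name to an object with
# .thresholds (a list of threshold objects exposing .metric_name,
# .threshold_direction and .threshold_value) plus .color and .lw for plotting,
# matching the attribute accesses in get_merge_candidates,
# _apply_metric_matrix_thresholds and _highlight_matrix_indices.
_ExampleThreshold = namedtuple(
    "ExampleThreshold", ["metric_name", "threshold_direction", "threshold_value"]
)
_ExampleThresholdSet = namedtuple(
    "ExampleThresholdSet", ["thresholds", "color", "lw"]
)
_example_threshold_sets = {
    "burst_pair": _ExampleThresholdSet(
        thresholds=[
            _ExampleThreshold("cosine_similarities", operator.ge, 0.9),
            _ExampleThreshold("correlogram_asymmetries", operator.ge, 0.7),
        ],
        color="crimson",
        lw=2,
    ),
}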
def merge_plots_wrapper(
cluster_data,
threshold_sets,
fig_scale=0.8,
subplot_width=4,
subplot_height=3,
plot_merge_candidates=None,
):
# Define plot parameters
num_rows = 2
num_columns = 4
gridspec_kw = {"width_ratios": [1, 1, 4, 4]}
for sort_group_id, data in cluster_data["sort_groups"].items():
# Apply threshold sets
valid_bool_map = get_above_threshold_matrix_indices(
cluster_data, sort_group_id, threshold_sets
)
        # Continue if none of the passed merge candidates are in the current sort group
if plot_merge_candidates is not None:
if sort_group_id not in [x[0] for x in plot_merge_candidates]:
continue
# Plot matrices with pairwise metrics relevant for merging
plot_merge_matrices(
cluster_data,
sort_group_id,
valid_bool_map,
threshold_sets,
fig_scale=fig_scale,
)
# For threshold sets, plot metrics for merge candidates
for (
threshold_name,
valid_bool,
) in valid_bool_map.items(): # threshold sets
# Find indices in array corresponding to merge candidates
merge_candidate_idxs = list(zip(*np.where(valid_bool)))
# Convert merge candidate indices in array to unit IDs
merge_candidates = [
tuple(np.asarray(data["unit_ids"])[np.asarray(idxs)])
for idxs in merge_candidate_idxs
]
# Loop through merge candidates and plot metrics
unit_colors = ["crimson", "#2196F3"]
for unit_1, unit_2 in merge_candidates: # units
if plot_merge_candidates is not None:
if (
sort_group_id,
unit_1,
unit_2,
) not in plot_merge_candidates:
continue
# Initialize figure
fig, axes = plt.subplots(
num_rows,
num_columns,
figsize=(
num_columns * subplot_width,
num_rows * subplot_height,
),
gridspec_kw=gridspec_kw,
)
# Use peak channel of first unit to display data from both units
peak_ch = data["peak_channels"][unit_1]
# Leftmost subplots: average waveforms
for unit_id_idx, unit_id in enumerate([unit_1, unit_2]):
title = f"{sort_group_id}_{unit_id}"
if unit_id_idx == 1:
cosine_similarity = data["cosine_similarities"][unit_1][
unit_2
]
title = f"cosine similarity: {cosine_similarity: .2f}\n{title}"
gs = axes[0, unit_id_idx].get_gridspec()
# Remove underlying axis
for row_num in np.arange(0, num_rows):
axes[row_num, unit_id_idx].remove()
ax = fig.add_subplot(gs[:, unit_id_idx])
# ax = axes[0, unit_id_idx]
plot_average_waveforms(
cluster_data,
sort_group_id,
unit_id,
title=title,
color=unit_colors[unit_id_idx],
ax=ax,
)
# Second subplot: amplitude distributions
ax = axes[0, 2]
for unit_id_idx, unit_id in enumerate([unit_1, unit_2]):
title = f"amplitude overlap: {data['amplitude_overlaps'][unit_1][unit_2]: .3f}"
plot_amplitude_distribution(
cluster_data,
sort_group_id,
unit_id,
ch=peak_ch,
max_amplitude=None,
amplitude_bin_size=2,
histtype="step",
density=True,
label=f"{sort_group_id}_{unit_id}",
color=unit_colors[unit_id_idx],
title=title,
ax=ax,
)
# Third subplot: correlograms
ax = axes[0, 3]
plot_correlogram(
cluster_data,
sort_group_id,
unit_1,
unit_2,
max_time_difference=200 / 1000,
color="gray",
ax=ax,
)
# Fourth subplot: amplitudes over time
# Use peak channel from first unit to plot amplitudes for both units
gs = axes[1, 2].get_gridspec()
# Remove underlying axes
for ax in axes[1, 2:]:
ax.remove()
ax = fig.add_subplot(gs[1, 2:])
# Plot amplitudes over time for each unit
for unit_id_idx, unit_id in enumerate([unit_1, unit_2]):
plot_amplitude(
cluster_data,
sort_group_id,
unit_id,
peak_ch,
color=unit_colors[unit_id_idx],
ax=ax,
)
# Global title
fig.suptitle(
f"{sort_group_id}_{unit_1} vs. {sort_group_id}_{unit_2}\n{threshold_name}",
fontsize=20,
)
fig.tight_layout()
# VISUALIZATION
def plot_amplitude(
cluster_data, sort_group_id, unit_id, ch, color="black", ax=None
):
# Get inputs if not passed
if ax is None:
_, ax = plt.subplots()
# Plot amplitudes over time
data = cluster_data["sort_groups"][sort_group_id]
ax.scatter(
data["spike_times"][unit_id],
data["amplitudes"][unit_id][ch, :],
s=1,
color=color,
)
def _matrix_grid(ax, n_clusters, fig_scale):
for ndx in range(n_clusters - 1):
ax.axvline(x=ndx + 1, color="#FFFFFF", linewidth=fig_scale * 0.5)
ax.axhline(y=ndx + 1, color="#FFFFFF", linewidth=fig_scale * 0.5)
def _format_matrix_ax(ax, ticks, ticklabels, fig_scale, title):
ax.set_xticks(ticks)
ax.set_yticks(ticks)
ax.set_xticklabels(ticklabels)
ax.set_yticklabels(ticklabels)
ax.tick_params(length=0)
for spine in ax.spines.values():
spine.set_visible(False)
ax.set_title(title, fontsize=fig_scale * 12)
def get_above_threshold_matrix_indices(
cluster_data, sort_group_id, threshold_sets
):
return {
threshold_name: _apply_metric_matrix_thresholds(
cluster_data, sort_group_id, threshold_set.thresholds
)
for threshold_name, threshold_set in threshold_sets.items()
}
def _highlight_matrix_indices(valid_bool_map, threshold_sets, ax):
for threshold_name, valid_bool in valid_bool_map.items():
ii, jj = np.where(valid_bool)
for ndx in range(np.sum(valid_bool)):
ax.add_patch(
matplotlib.patches.Rectangle(
(jj[ndx], ii[ndx]),
1,
1,
edgecolor=threshold_sets[threshold_name].color,
fill=False,
lw=threshold_sets[threshold_name].lw,
zorder=2 * len(valid_bool) ** 2,
clip_on=False,
)
)
def _get_metric_matrix(
cluster_data,
sort_group_id,
metric_name,
apply_upper_diagonal_mask=False,
mask_value=np.nan,
):
data = cluster_data["sort_groups"][sort_group_id]
metric_dict = data[metric_name]
index = metric_dict.keys()
    matrix = np.array(
        [[metric_dict[ii][jj] for jj in index] for ii in index]
    ).astype(float)  # float so we can mask with nan
# Mask upper diagonal if indicated
if apply_upper_diagonal_mask:
matrix = mask_upper_diagonal(matrix, mask_value=mask_value)
return pd.DataFrame(matrix, index=index, columns=index)
def _apply_metric_matrix_thresholds(
cluster_data, sort_group_id, threshold_objs
):
return np.prod(
[
threshold_obj.threshold_direction(
_get_metric_matrix(
cluster_data,
sort_group_id,
threshold_obj.metric_name,
apply_upper_diagonal_mask=True,
mask_value=np.nan,
),
threshold_obj.threshold_value,
)
for threshold_obj in threshold_objs
],
axis=0,
)
def plot_amplitude_overlap_matrix(
cluster_data,
sort_group_id,
fig_scale=1,
fig_ax_list=None,
plot_color_bar=True,
):
data = cluster_data["sort_groups"][sort_group_id]
n_clusters = data["n_clusters"]
# Get amplitude overlap matrix
ao_matrix = _get_metric_matrix(
cluster_data,
sort_group_id,
"amplitude_overlaps",
apply_upper_diagonal_mask=True,
mask_value=0,
)
# Unpack figure and axis if passed
if fig_ax_list is not None:
fig, ax = fig_ax_list
# Otherwise make these
else:
        fig = plt.figure(
            figsize=(fig_scale * n_clusters / 2, fig_scale * n_clusters / 2)
        )
gs = fig.add_gridspec(1, 1)
ax = fig.add_subplot(gs[0])
pcm = plt.pcolormesh(ao_matrix, cmap="inferno", vmin=0, vmax=1)
_matrix_grid(ax, n_clusters, fig_scale)
label = "".join(
(
cluster_data["nwb_file_name"],
"\n",
"interval: ",
cluster_data["sort_interval_name"],
"\n",
f"sort group: {sort_group_id}",
"\n" "amplitude overlap",
)
)
_format_matrix_ax(
ax,
ticks=np.arange(0.5, n_clusters + 0.5),
ticklabels=ao_matrix.index,
fig_scale=fig_scale,
title=label,
)
# Color bar
if plot_color_bar:
add_colorbar(pcm, fig, ax)
return fig, ax
def plot_merge_matrices(
cluster_data,
sort_group_id,
valid_bool_map,
threshold_sets,
fig_scale=1,
plot_color_bar=True,
):
data = cluster_data["sort_groups"][sort_group_id]
n_clusters = data["n_clusters"]
# Get cosine similarity matrix
cs_matrix = _get_metric_matrix(
cluster_data,
sort_group_id,
"cosine_similarities",
apply_upper_diagonal_mask=True,
mask_value=0,
)
# Get correlogram asymmetry matrix
ca_matrix = _get_metric_matrix(
cluster_data,
sort_group_id,
"correlogram_asymmetries",
apply_upper_diagonal_mask=True,
mask_value=0,
)
# Initialize figure
num_columns = 3
fig = plt.figure(
figsize=(
fig_scale * (num_columns * n_clusters / 2 + 2),
fig_scale * (n_clusters / 2),
)
)
width_ratios = [n_clusters / 2] * 3
gs = fig.add_gridspec(1, num_columns, wspace=0.2, width_ratios=width_ratios)
# Ticks across plots
ticks = np.arange(0.5, n_clusters + 0.5)
# First subplot: cosine similarity
ax = fig.add_subplot(gs[0])
pcm = plt.pcolormesh(cs_matrix, cmap="inferno", vmin=0, vmax=1)
# Grid
_matrix_grid(ax, n_clusters, fig_scale)
# Highlight indices crossing metric thresholds
_highlight_matrix_indices(valid_bool_map, threshold_sets, ax)
# Axis
_format_matrix_ax(
ax=ax,
ticks=ticks,
ticklabels=cs_matrix.index,
fig_scale=fig_scale,
title="cosine similarity",
)
# Color bar
if plot_color_bar:
add_colorbar(pcm, fig, ax)
# Second subplot: correlogram asymmetry
ax = fig.add_subplot(gs[1])
pcm = plt.pcolormesh(ca_matrix, cmap="inferno", vmin=0.5, vmax=1)
# Grid
_matrix_grid(ax, n_clusters, fig_scale)
# Highlight indices crossing metric thresholds
_highlight_matrix_indices(valid_bool_map, threshold_sets, ax)
# Axis
_format_matrix_ax(
ax=ax,
ticks=ticks,
ticklabels=ca_matrix.index,
fig_scale=fig_scale,
title="correlogram asymmetry",
)
# Color bar
if plot_color_bar:
add_colorbar(pcm, fig, ax)
# Third subplot: amplitude overlap
ax = fig.add_subplot(gs[2])
fig, ax = plot_amplitude_overlap_matrix(
cluster_data,
sort_group_id,
fig_scale=fig_scale,
fig_ax_list=[fig, ax],
plot_color_bar=plot_color_bar,
)
# Highlight indices crossing metric thresholds
_highlight_matrix_indices(valid_bool_map, threshold_sets, ax)
plt.show()
def plot_average_waveforms(
cluster_data,
sort_group_id,
unit_id,
amplitude_range=80,
trace_offset=40,
color="#2196F3",
title=None,
ax=None,
):
# Get inputs if not passed
if ax is None:
_, ax = plt.subplots()
if title is None:
title = f"{sort_group_id}_{unit_id}"
data = cluster_data["sort_groups"][sort_group_id]
n_channels = data["n_channels"]
n_points = data["waveform_window"].size
ax.axvline(x=n_points / 2, color="#9E9E9E", linewidth=1)
offset = np.tile(-np.arange(n_channels) * trace_offset, (n_points, 1))
wv_avg = data["average_waveforms"][unit_id]
trace = wv_avg.T + offset
peak_ind = np.full(n_channels, False)
peak_ind[data["peak_channels"][unit_id]] = True
ax.plot(trace[:, ~peak_ind], color=color, linewidth=1, clip_on=False)
ax.plot(trace[:, peak_ind], color=color, linewidth=2.5, clip_on=False)
ax.set_xlim([0, n_points])
ax.set_ylim(
[
-2 * amplitude_range / 3 - (n_channels - 1) * trace_offset,
amplitude_range / 3,
]
)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
ax.set_title(title, fontsize=12)
def plot_average_waveforms_wrapper(
cluster_data, amplitude_range=80, trace_offset=40
):
for sort_group_id, data in cluster_data["sort_groups"].items():
n_clusters = data["n_clusters"]
fig = plt.figure(figsize=(n_clusters + 2, 1))
width_ratios = np.ones(n_clusters + 1)
width_ratios[0] = 2
gs = fig.add_gridspec(
1, n_clusters + 1, wspace=0.1, width_ratios=width_ratios
)
ax = fig.add_subplot(gs[0])
label = "".join(
(
cluster_data["nwb_file_name"],
"\n",
"interval: ",
cluster_data["sort_interval_name"],
"\n",
f"sort group: {sort_group_id}",
)
)
ax.text(-0.3, 0.3, label, multialignment="left", fontsize=12)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
for spine in ax.spines.values():
spine.set_visible(False)
        for ndx, unit_id in enumerate(data["unit_ids"]):
ax = fig.add_subplot(gs[ndx + 1])
plot_average_waveforms(
cluster_data,
sort_group_id,
unit_id,
amplitude_range=amplitude_range,
trace_offset=trace_offset,
ax=ax,
)
plt.show()
def plot_amplitude_distribution(
cluster_data,
sort_group_id,
unit_id,
ch=None,
max_amplitude=None,
amplitude_bin_size=2,
density=False,
histtype=None,
color="#2196F3",
label=None,
title=None,
remove_axes=False,
ax=None,
):
data = cluster_data["sort_groups"][sort_group_id]
# Define channel if not passed
if ch is None:
ch = data["peak_channels"][unit_id]
amp = data["amplitudes"][unit_id][ch, :]
# Get inputs if not passed
if ax is None:
_, ax = plt.subplots()
if max_amplitude is None:
max_amplitude = np.max(amp)
if title is None:
title = f"{sort_group_id}_{unit_id}"
bin_edges = np.arange(
0, max_amplitude + amplitude_bin_size, amplitude_bin_size
)
ax.hist(
amp,
bin_edges,
density=density,
histtype=histtype,
color=color,
label=label,
linewidth=4,
alpha=0.9,
)
# ax.set_xlim([bin_edges[0], bin_edges[-1]])
if remove_axes:
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
ax.set_title(title, fontsize=12)
def plot_amplitude_distributions(
cluster_data, max_amplitude=50, amplitude_bin_size=2
):
for sort_group_id, data in cluster_data["sort_groups"].items():
n_clusters = data["n_clusters"]
fig = plt.figure(figsize=(n_clusters + 2, 1))
width_ratios = np.ones(n_clusters + 1)
width_ratios[0] = 2
gs = fig.add_gridspec(
1, n_clusters + 1, wspace=0.1, width_ratios=width_ratios
)
ax = fig.add_subplot(gs[0])
label = "".join(
(
cluster_data["nwb_file_name"],
"\n",
"interval: ",
cluster_data["sort_interval_name"],
"\n",
f"sort group: {sort_group_id}",
)
)
ax.text(-0.3, 0.3, label, multialignment="center", fontsize=12)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
for spine in ax.spines.values():
spine.set_visible(False)
for ndx, (unit_id, time_diff) in enumerate(data["amplitudes"].items()):
ax = fig.add_subplot(gs[ndx + 1])
plot_amplitude_distribution(
cluster_data,
sort_group_id,
unit_id,
max_amplitude=max_amplitude,
amplitude_bin_size=amplitude_bin_size,
ax=ax,
)
plt.show()
def plot_correlogram(
cluster_data,
sort_group_id,
cluster_1,
cluster_2,
max_time_difference=20 / 1000,
time_bin_size=1 / 1000,
color="#2196F3",
remove_axes=False,
ax=None,
):
# Get inputs if not passed
if ax is None:
_, ax = plt.subplots()
data = cluster_data["sort_groups"][sort_group_id]
time_diff = data["correlograms"][cluster_1][cluster_2]
bin_edges = np.arange(
-max_time_difference, max_time_difference + time_bin_size, time_bin_size
)
ax.hist(time_diff, bin_edges, color=color)
ax.set_xlim([bin_edges[0], bin_edges[-1]])
ax.set_ylim([0, np.max(np.histogram(time_diff, bin_edges)[0])])
if remove_axes:
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
burst_pair_amplitude_timing_bool = data[
"burst_pair_amplitude_timing_bools"
][cluster_1][cluster_2]
correlogram_asymmetry = data["correlogram_asymmetries"][cluster_1][
cluster_2
]
isi_violation = data["unit_pair_percent_isi_violations"][cluster_1][
cluster_2
]
correlogram_count = int(data["correlogram_counts"][cluster_1][cluster_2])
ax.set_title(
f"{sort_group_id}_{cluster_1} vs {sort_group_id}_{cluster_2}"
f"\ncount: {correlogram_count: .2f}"
f"\nasymmetry: {correlogram_asymmetry: .2f}"
f"\nISI violation: {isi_violation:.5f}"
f"\nburst_pair_amplitude_timing_bool: {burst_pair_amplitude_timing_bool}",
fontsize=12,
)
def plot_autocorrelograms(
cluster_data, max_time_difference=20 / 1000, time_bin_size=1 / 1000
):
for sort_group_id, data in cluster_data["sort_groups"].items():
n_clusters = data["n_clusters"]
fig = plt.figure(figsize=(n_clusters + 2, 1))
width_ratios = np.ones(n_clusters + 1)
width_ratios[0] = 2
gs = fig.add_gridspec(
1, n_clusters + 1, wspace=0.1, width_ratios=width_ratios
)
ax = fig.add_subplot(gs[0])
label = "".join(
(
cluster_data["nwb_file_name"],
"\n",
"interval: ",
cluster_data["sort_interval_name"],
"\n",
f"sort group: {sort_group_id}",
)
)
ax.text(-0.3, 0.3, label, multialignment="center", fontsize=12)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
for spine in ax.spines.values():
spine.set_visible(False)
for ndx, cluster_num in enumerate(data["correlograms"].keys()):
ax = fig.add_subplot(gs[ndx + 1])
plot_correlogram(
cluster_data,
sort_group_id,
cluster_1=cluster_num,
cluster_2=cluster_num,
max_time_difference=max_time_difference,
time_bin_size=time_bin_size,
ax=ax,
)
plt.show()
def plot_cosine_similarity_distribution(cluster_data, fig_scale=1):
n_clusters = np.array(
[data["n_clusters"] for data in cluster_data["sort_groups"].values()]
)
ind = [np.triu(np.full((count, count), True), 1) for count in n_clusters]
cs_list = [
np.array(
[
[
data["cosine_similarities"][ii][jj]
for jj in data["cosine_similarities"].keys()
]
for ii in data["cosine_similarities"].keys()
]
)
for data in cluster_data["sort_groups"].values()
]
cs = [list(np.ravel(cs[ind[ndx]])[:]) for ndx, cs in enumerate(cs_list)]
cs = np.array(list(itertools.chain(*cs)))
fig = plt.figure(figsize=(fig_scale * 12, fig_scale * 3))
gs = fig.add_gridspec(1, 1)
ax = fig.add_subplot(gs[0])
label = "".join(
(
cluster_data["nwb_file_name"],
"\n",
"interval: ",
cluster_data["sort_interval_name"],
)
)
plt.hist(cs, 240, color="#2196F3")
ax.axvline(x=0, color="#424242", linewidth=fig_scale * 1.0)
# ax.axvline(x=0.9, color='#424242', linewidth=2)
ax.set_xlim([-1, 1])
ax.set_xticks([-1, -0.5, 0, 0.5, 1])
ax.set_xticklabels([-1, -0.5, 0, 0.5, 1], fontsize=fig_scale * 10)
ax.set_yticks([0, 20, 40, 60])
ax.set_yticklabels([0, 20, 40, 60], fontsize=fig_scale * 10)
ax.set_xlabel("Cosine Similarity", fontsize=fig_scale * 12)
ax.set_ylabel("Count", fontsize=fig_scale * 12)
ax.set_title(label, fontsize=fig_scale * 12)
plt.show()
def plot_correlogram_asymmetry_distribution(cluster_data, fig_scale=1):
n_clusters = np.array(
[data["n_clusters"] for data in cluster_data["sort_groups"].values()]
)
ind = [np.triu(np.full((count, count), True), 1) for count in n_clusters]
ca_list = [
np.array(
[
[
data["correlogram_asymmetries"][ii][jj]
for jj in data["correlogram_asymmetries"].keys()
]
for ii in data["correlogram_asymmetries"].keys()
]
)
for data in cluster_data["sort_groups"].values()
]
ca = [list(np.ravel(ca[ind[ndx]])[:]) for ndx, ca in enumerate(ca_list)]
ca = np.array(list(itertools.chain(*ca)))
fig = plt.figure(figsize=(fig_scale * 12, fig_scale * 3))
gs = fig.add_gridspec(1, 1)
ax = fig.add_subplot(gs[0])
label = "".join(
(
cluster_data["nwb_file_name"],
"\n",
"interval: ",
cluster_data["sort_interval_name"],
)
)
plt.hist(ca, 240, color="#2196F3")
ax.set_yscale("log")
ax.set_xlim([0.5, 1])
ax.set_xticks([0.5, 0.6, 0.7, 0.8, 0.9, 1])
ax.set_xticklabels([0.5, 0.6, 0.7, 0.8, 0.9, 1], fontsize=fig_scale * 10)
ax.set_yticks([1, 10, 100, 1000])
ax.set_yticklabels([1, 10, 100, 1000], fontsize=fig_scale * 10)
ax.set_xlabel("Correlogram Asymmetry", fontsize=fig_scale * 12)
ax.set_ylabel("Count", fontsize=fig_scale * 12)
ax.set_title(label, fontsize=fig_scale * 12)
plt.show()
def check_all_unique(x):
if len(np.unique(x)) != len(x):
raise Exception(f"Not all elements unique")
def strip(x, strip_character, strip_start=False, strip_end=True):
if strip_start:
if x[0] == strip_character:
x = x[1:]
if strip_end:
if x[-1] == strip_character:
x = x[:-1]
return x
def df_filter_columns(df, key, column_and=True):
if column_and:
return df[
np.asarray([df[k] == v for k, v in key.items()]).sum(axis=0)
== len(key)
]
else:
return df[
np.asarray([df[k] == v for k, v in key.items()]).sum(axis=0) > 0
]
def df_filter1_columns(df, key, tolerate_no_entry=False):
df_subset = df_filter_columns(df, key)
if np.logical_or(
len(df_subset) > 1, not tolerate_no_entry and len(df_subset) == 0
):
raise Exception(
f"Should have found exactly one entry in df for key, but found {len(df_subset)}"
)
return df_subset
def df_pop(df, key, column, tolerate_no_entry=False):
df_subset = df_filter1_columns(df, key, tolerate_no_entry)
if len(df_subset) == 0: # empty df
return df_subset
return df_subset.iloc[0][column]
def df_filter_columns_isin(df, key):
if len(key) == 0: # if empty key
return df
return df[
np.sum(np.asarray([df[k].isin(v) for k, v in key.items()]), axis=0)
== len(key)
]
# Alternate code: df[df[list(df_filter)].isin(df_filter).all(axis=1)]
def zip_df_columns(df, column_names=None):
if column_names is None:
column_names = df.columns
return zip(*[df[column_name] for column_name in column_names])
def nwbf_name_from_subject_id_date(subject_id, date):
return f"{subject_id}{date}_.nwb"
def subject_id_date_from_nwbf_name(nwb_file_name):
len_date = 8
subject_id_date = nwb_file_name.split("_.nwb")[0]
subject_id = subject_id_date[:-len_date]
date = subject_id_date[-len_date:]
return subject_id, date
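# Hedged illustration (not part of the original gist): the two helpers above are
# inverses of each other.
def _example_nwbf_name_roundtrip():
    assert nwbf_name_from_subject_id_date("tonks", "20211107") == "tonks20211107_.nwb"
    assert subject_id_date_from_nwbf_name("tonks20211107_.nwb") == ("tonks", "20211107")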
def unpack_single_element(x, tolerate_no_entry=False, return_no_entry=None):
if tolerate_no_entry:
if len(x) == 0:
return return_no_entry
return unpack_single_element(x, tolerate_no_entry=False)
if len(x) != 1:
raise Exception(f"len should be one")
return x[0]
def mask_upper_diagonal(arr, mask_value=0):
    mask = np.zeros_like(arr, dtype=bool)
mask[np.tril_indices_from(mask)] = True
arr[mask] = mask_value
return arr
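# Hedged illustration (not from the original gist): despite its name, the function
# above masks the lower triangle and the diagonal in place, keeping only the
# strictly upper triangle (the unique unit pairs used for the merge matrices).
def _example_mask_upper_diagonal():
    arr = np.arange(1.0, 10.0).reshape(3, 3)
    masked = mask_upper_diagonal(arr, mask_value=0)
    assert np.array_equal(masked, [[0, 2, 3], [0, 0, 6], [0, 0, 0]])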
def cd_make_if_nonexistent(directory):
"""
Change to a directory if it exists. If it does not, make it then change to it.
:param directory: string. Directory to change to.
"""
if not os.path.exists(directory):
print(f"Making directory: {directory}")
os.mkdir(directory)
print(f"Changing to directory: {directory}")
os.chdir(directory)
def single_axis(axes):
return hasattr(axes, "plot")
def get_ax_for_left_right_layout(axes, plot_num):
"""
Return ax from axes if arranging plots left to right, top to bottom
:param axes: array with axis objects
:param plot_num: plot number
:return: current axis given plot number
"""
if single_axis(axes): # single axis object
return axes
if len(np.shape(axes)) == 1: # one row or one column of subplots
return axes[plot_num]
elif len(np.shape(axes)) == 2: # 2D panel of subplots
num_columns = np.shape(axes)[1]
row, col = divmod(
plot_num, num_columns
) # find row/column for current plot
return axes[row, col] # get axis for current plot
else:
raise Exception(f"axes do not conform to expected cases")
def _load_notes(
subject_id,
spreadsheet_name,
recording_spreadsheet_path=None,
header=0,
tolerate_no_notes=False,
):
"""
Load spreadsheet for a given subject
"""
# Get inputs if not passed
if recording_spreadsheet_path is None:
recording_spreadsheet_path = get_recording_spreadsheet_path(subject_id)
# Get file path
file_path = os.path.join(recording_spreadsheet_path, spreadsheet_name)
# If tolerating no notes and no notes, return empty df
if tolerate_no_notes and not os.path.exists(file_path):
return pd.DataFrame()
# Return recording spreadsheet as pandas dataframe
return pd.read_csv(file_path, header=header)
def load_curation_merge_notes(
subject_id, date, recording_spreadsheet_path=None, tolerate_no_notes=False
):
"""
Load recording spreadsheet from saved file for a given subject
"""
return _load_notes(
subject_id,
spreadsheet_name=f"curation_merge - {subject_id}{date}_summary.csv",
recording_spreadsheet_path=recording_spreadsheet_path,
header=[0, 1],
tolerate_no_notes=tolerate_no_notes,
)
# -
# ## calculate metrics for all cell pairs
# +
# Define dataset - need to re-do this spike sorting
subject_ids = ["tonks"] * 6
dates = [
"20211107",
"20211108",
"20211109",
"20211110",
"20211111",
"20211112",
]
# for nwb_file_name in ['tonks20211107_.nwb','tonks20211108_.nwb','tonks20211109_.nwb',
# 'tonks20211110_.nwb','tonks20211111_.nwb','tonks20211112_.nwb',]:
# subject_ids = ["ginny","ginny","ginny","ginny","ginny","ginny","ginny",
# "ginny","ginny","ginny","ginny","ginny",]
# dates = ["20211025","20211026","20211027","20211028","20211029","20211030","20211031","20211101",
# "20211102","20211103","20211104","20211105",]
# Make curation data
sort_interval_name = "r2_r3"
preproc_params_name = "franklab_tetrode_hippocampus"
# NOTE: all_tet_list is defined in a later cell below; run that cell before this one
sort_group_ids = all_tet_list  # [0] # , 1, 2] # None
# this should be the curation_id corresponding to automatic curation
curation_id = 1
overwrite_existing = False
verbose = True
# need to set make_data to true to generate new metrics
make_data = True
print("start", datetime.datetime.now())
if make_data:
make_curation_data_wrapper(
subject_ids,
dates,
sort_interval_name=sort_interval_name,
preproc_params_name=preproc_params_name,
sort_group_ids=sort_group_ids,
curation_id=curation_id,
overwrite_existing=overwrite_existing,
verbose=verbose,
)
print("end", datetime.datetime.now())
# -
# ## load calculated metrics and find candidates
# +
# try to combine all curation loading steps into one cell
# ginny
all_tet_list = (
    np.array(
        [
            1, 2, 4, 5, 7, 8, 11, 12, 13, 14, 15, 16, 17, 20, 21, 22,
            25, 26, 27, 28, 31, 33, 34, 35, 36, 37, 39, 41, 42, 43, 44,
            45, 47, 49, 51, 52, 54, 56, 57, 59, 61, 62, 63, 64,
        ]
    )
    - 1
)
# Load curation data
subject_ids = ["ginny"] * 21
dates = [
"20211023",
"20211024",
"20211025",
"20211026",
"20211027",
"20211028",
"20211029",
"20211030",
"20211031",
"20211101",
"20211102",
"20211103",
"20211104",
"20211105",
"20211106",
"20211108",
"20211109",
"20211110",
"20211111",
"20211112",
"20211113",
]
# Rebind subject_ids/dates to one-element lists so the steps below run on a single date
for date in dates:
    subject_ids = ["ginny"]
    dates = [date]
# set file name
def make_param_name(param_values):
    # NOTE: joins values with the literal separator "_2" (e.g. ["a", "b"] -> "a_2b")
    return "_2".join([str(x) for x in param_values])
# Make curation data
sort_interval_name = "r2_r3"
preproc_params_name = "franklab_tetrode_hippocampus"
sort_group_ids = all_tet_list  # alternatives tried: [0], [0, 1, 2], None
# this should be the curation_id corresponding to automatic curation
curation_id = 1
overwrite_quantities = False
use_old = False
target_region = None  # only use if you have saved out data with a target region name using the code commented out below
use_target_region = False
verbose = True
file_path_base = "/cumulus/mcoulter/curation_data/" # PUT PATH WHERE YOU WANT TO SAVE FILES HERE
cluster_data_container = dict()
for subject_id, date in zip(subject_ids, dates):
# Define directory to save data in
save_dir = f"{file_path_base}/{subject_id}/" + "old" * use_old
# Get nwb file name
nwb_file_name = nwbf_name_from_subject_id_date(subject_id, date)
cluster_data_container[nwb_file_name] = load_curation_data(
save_dir=save_dir,
nwb_file_name=nwb_file_name,
sort_interval_name=sort_interval_name,
preproc_params_name=preproc_params_name,
sort_group_ids=sort_group_ids,
target_region=target_region,
curation_id=curation_id,
overwrite_quantities=overwrite_quantities,
verbose=verbose,
)
# Get curation merge notes
curation_merge_notes_map = {
nwbf_name_from_subject_id_date(
subject_id, date
): get_curation_spreadsheet(subject_id, date)
for subject_id, date in zip(subject_ids, dates)
}
# Make df with metrics and merge label for all units
# Note that this should be done separately for merge types (e.g. burst pair and split cell),
# so that each cell falls into only one category: "true" merge pair, "false" merge pair, or unlabeled
# Get metrics for these merge pairs
merge_types = ["burst_pair", "split_cell"]
merge_pair_identifiers = ["sort_group_id", "unit_id_1", "unit_id_2"]
labels = [True, False]
merge_type_metric_names_map = {
"burst_pair": [
"cosine_similarities",
"correlogram_asymmetries",
"unit_pair_percent_isi_violations",
"correlogram_isi_violation_ratios",
"correlogram_counts",
"burst_pair_amplitude_timing_bools",
"unit_merge_valid_lower_amplitude_fractions",
"unit_merge_amplitude_decrements_0.015",
"unit_merge_amplitude_decrements_0.4",
"amplitude_decrement_changes_0.015",
"amplitude_decrement_changes_0.4",
],
"split_cell": [
"cosine_similarities",
"amplitude_overlaps",
"unit_pair_percent_isi_violations",
"correlogram_isi_violation_ratios",
"correlogram_counts",
],
}
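# For reference (structure inferred from the indexing used below, not a
# documented API): each metric above is expected to live at
#     cluster_data["sort_groups"][sort_group_id][metric_name][unit_id_1][unit_id_2]
# i.e. a pairwise value for the two units of a candidate merge pair.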
column_names = [
"nwb_file_name",
"metric_name",
"label",
"merge_tuple",
"metric_value",
]
# Get metrics for these nwb file names
nwb_file_names = list(cluster_data_container.keys())
# Check that nwb file names are well defined (redundant here, since nwb_file_names comes straight from cluster_data_container, but kept as a guard)
if not all([x in cluster_data_container.keys() for x in nwb_file_names]):
raise Exception(
f"Can only make metrics for nwb file names in cluster_data_container, which are: {cluster_data_container.keys()}"
)
# Get units with "true" and "false" human labels
merge_df_map = {
merge_type: pd.DataFrame(columns=column_names)
for merge_type in merge_types
} # initialize
for merge_type_idx, merge_type in enumerate(merge_types):
metric_names = merge_type_metric_names_map[merge_type]
data_list = []
for nwb_file_name in nwb_file_names:
cluster_data = cluster_data_container[nwb_file_name] # cluster data
curation_merge_notes = curation_merge_notes_map[
nwb_file_name
] # notes with "true" and "false" human labels
# Continue if curation_merge_notes empty
if len(curation_merge_notes) == 0:
continue
# Otherwise, get metrics for units labeled as "true" or "false" merge pairs from curation_merge_notes
for metric_name in metric_names:
for label in labels:
# Define merge tuples as those present in curation notes
merge_tuples = list(
zip_df_columns(
df_filter_columns(
curation_merge_notes_map[nwb_file_name],
{"merge_type": merge_type, "label": label},
),
merge_pair_identifiers,
)
)
# Remove merge tuples with sort group not in cluster data
merge_tuples = [
x
for x in merge_tuples
if x[0] in cluster_data["sort_groups"]
]
for merge_tuple in merge_tuples:
sort_group_id, unit_1, unit_2 = merge_tuple
# Continue if metric not calculated for current sort group
if (
metric_name
not in cluster_data["sort_groups"][sort_group_id]
):
continue
metric_value = cluster_data["sort_groups"][
sort_group_id
][metric_name][unit_1][unit_2]
data_list.append(
(
nwb_file_name,
metric_name,
label,
merge_tuple,
metric_value,
)
)
    # Only update if there is data; otherwise the initialized column names would be overwritten with an empty frame
if len(data_list) > 0:
merge_df_map[merge_type] = pd.DataFrame.from_dict(
{k: v for k, v in zip(column_names, list(zip(*data_list)))}
)
# Get unlabeled units
for merge_type, merge_df in merge_df_map.items():
metric_names = merge_type_metric_names_map[merge_type]
data_list = []
    for nwb_file_name in nwb_file_names:
        cluster_data = cluster_data_container[nwb_file_name]  # cluster data for this nwb file
        # ...Get all unit pairs for this nwb file (limited to cluster data files that have been created)
        sort_group_id_unit_pair_map = {
sort_group_id: list(
itertools.combinations(data["unit_ids"], r=2)
)
for sort_group_id, data in cluster_data["sort_groups"].items()
}
merge_tuples = [
(sort_group_id, unit_1, unit_2)
for sort_group_id, unit_pairs in sort_group_id_unit_pair_map.items()
for (unit_1, unit_2) in unit_pairs
]
# ...Get labeled units for this nwb file
merge_df_subset = df_filter_columns(
merge_df, {"nwb_file_name": nwb_file_name}
)
labeled_merge_tuples = merge_df_subset["merge_tuple"]
merge_tuples = set(merge_tuples) - set(
labeled_merge_tuples
) # unit pairs that were not labeled as true or false merge pairs
for merge_tuple in merge_tuples:
sort_group_id, unit_1, unit_2 = merge_tuple
for metric_name in metric_names:
# Continue if metric not calculated for current sort group
if (
metric_name
not in cluster_data["sort_groups"][sort_group_id]
):
continue
metric_value = cluster_data["sort_groups"][sort_group_id][
metric_name
][unit_1][unit_2]
data_list.append(
(
nwb_file_name,
metric_name,
"none",
merge_tuple,
metric_value,
)
)
# Update merge df
other_df = pd.DataFrame.from_dict(
{k: v for k, v in zip(column_names, list(zip(*data_list)))}
)
merge_df_map[merge_type] = pd.concat((merge_df, other_df))
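# Quick sanity check (an addition, not in the original script): tally how many
# unit pairs carry each label (True/False from the curation notes, "none" for
# unlabeled) per merge type.
for _merge_type, _merge_df in merge_df_map.items():
    if "label" in _merge_df:
        print(_merge_type, _merge_df["label"].value_counts(dropna=False).to_dict())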
# Re-run with new thresholds: raised the amplitude-overlap threshold to 0.5 after checking a few days
Threshold = namedtuple(
"threshold", "metric_name threshold_value threshold_direction"
)
Thresholds = namedtuple("thresholds", "name thresholds color lw")
threshold_sets = [
(
"burst_pair",
[
("cosine_similarities", 0.7, operator.gt),
("correlogram_asymmetries", 0.6, operator.gt),
("correlogram_counts", 100, operator.gt),
("unit_pair_percent_isi_violations", 0.25, operator.lt),
("burst_pair_amplitude_timing_bools", 0, operator.gt),
],
"#2196F3",
9,
),
(
"split_cell",
[
("cosine_similarities", 0.5, operator.gt),
("correlogram_counts", 100, operator.gt),
("unit_pair_percent_isi_violations", 0.25, operator.lt),
("amplitude_overlaps", 0.5, operator.gt),
],
"limegreen",
3,
),
# ("test",
# [("cosine_similarities", 0, operator.gt)],
# "orange",
# 5)
]
threshold_sets = {
name: Thresholds(name, [Threshold(*x) for x in thresholds], color, lw)
for name, thresholds, color, lw in threshold_sets
}
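# Hypothetical sketch (get_above_threshold_matrix_indices itself is defined
# elsewhere): each Threshold stores a comparison operator, so a single metric
# value can be tested against its cutoff like this.
def _passes_threshold(metric_value, threshold):
    # threshold_direction is operator.gt or operator.lt from the tuples above
    return threshold.threshold_direction(metric_value, threshold.threshold_value)

# e.g. _passes_threshold(0.8, threshold_sets["burst_pair"].thresholds[0]) is True
# because cosine_similarities uses (0.7, operator.gt)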
# print merge candidates
merge_count = 0
# nwb_file_name here is the last (and, with one session per run, only) file loaded above
cluster_data = cluster_data_container[nwb_file_name]
tuple_list = []
for sort_group_id, data in cluster_data["sort_groups"].items():
# print('merge candidate',sort_group_id,'_',unit_1,'_vs._',
# sort_group_id,'_',unit_2,'_')
valid_bool_map = get_above_threshold_matrix_indices(
cluster_data, sort_group_id, threshold_sets
)
for (
threshold_name,
valid_bool,
) in valid_bool_map.items(): # threshold sets
# Find indices in array corresponding to merge candidates
merge_candidate_idxs = list(zip(*np.where(valid_bool)))
# Convert merge candidate indices in array to unit IDs
merge_candidates = [
tuple(np.asarray(data["unit_ids"])[np.asarray(idxs)])
for idxs in merge_candidate_idxs
]
# Loop through merge candidates and plot metrics
for unit_1, unit_2 in merge_candidates: # units
# print(nwb_file_name,': merge candidate for tetrode',sort_group_id,', clusters',unit_1,'and',
# unit_2,'are a',threshold_name, '(',sort_group_id,unit_1,unit_2,')')
# print(tuple([sort_group_id,unit_1,unit_2]))
# create list of merge candidates
tuple_list.append(tuple([sort_group_id, unit_1, unit_2]))
merge_count += 1
print("merge count", merge_count)
print([dates])
print(len(tuple_list))
print(tuple_list)
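# Optional summary (an addition to the original prints): tally merge candidates
# per sort group; tuple_list holds (sort_group_id, unit_id_1, unit_id_2) tuples.
from collections import Counter

print(Counter(sort_group_id for sort_group_id, _, _ in tuple_list))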
# -