Merge pull request #741 from MouseLand/jacob/output_docs
Jacob/output docs
jacobpennington committed Aug 3, 2024
2 parents b82e562 + 5c2178e commit 3f3e28d
Showing 2 changed files with 230 additions and 20 deletions.
123 changes: 123 additions & 0 deletions kilosort/io.py
@@ -212,6 +212,129 @@ def remove_bad_channels(probe, bad_channels):
def save_to_phy(st, clu, tF, Wall, probe, ops, imin, results_dir=None,
data_dtype=None, save_extra_vars=False,
save_preprocessed_copy=False):
"""Save sorting results to disk in a format readable by Phy.
Parameters
----------
st : np.ndarray
3-column array of peak time (in samples), template, and amplitude for
each spike.
clu : np.ndarray
1D vector of cluster ids indicating which spike came from which cluster,
same shape as `st[:,0]`.
tF : torch.Tensor
PC features for each spike, with shape
(n_spikes, nearest_chans, n_pcs)
Wall : torch.Tensor
PC feature representation of spike waveforms for each cluster, with shape
(n_clusters, n_channels, n_pcs).
probe : dict
A Kilosort4 probe dictionary, as returned by `kilosort.io.load_probe`.
ops : dict
Dictionary storing settings and results for all algorithmic steps.
imin : int
Minimum sample index used by BinaryRWFile; exported spike times will
be shifted forward by this number.
results_dir : pathlib.Path; optional.
Directory where results should be saved.
data_dtype : str or type; optional.
dtype of data in binary file, like `'int32'` or `np.uint16`. By default,
dtype is assumed to be `'int16'`.
save_extra_vars : bool; default=False.
If True, save tF and Wall to disk along with copies of st, clu and
amplitudes with no postprocessing applied.
save_preprocessed_copy : bool; default=False.
If True, save a pre-processed copy of the data (including drift
correction) to `temp_wh.dat` in the results directory and format Phy
output to use that copy of the data.
Returns
-------
results_dir : pathlib.Path.
Directory where results are saved.
similar_templates : np.ndarray.
Similarity score between each pair of clusters, computed as correlation
between clusters. Shape (n_clusters, n_clusters).
is_ref : np.ndarray.
1D boolean array with shape (n_clusters,) indicating whether each
cluster is refractory.
est_contam_rate : np.ndarray.
Contamination rate for each cluster, computed as fraction of refractory
period violations relative to expectation based on a Poisson process.
Shape (n_clusters,).
kept_spikes : np.ndarray.
Boolean mask with shape (n_spikes,) that is False for spikes that were
removed by `kilosort.postprocessing.remove_duplicate_spikes`
and True otherwise.
Notes
-----
The following files will be saved in `results_dir`. Note that 'template'
here does *not* refer to the universal or learned templates used for spike
detection, as it did in some past versions of Kilosort. Instead, it refers
to the average spike waveform (after whitening, filtering, and drift
correction) for all spikes assigned to each cluster, which are template-like
in shape. We use the term 'template' anyway for this section because that is
how they are treated in Phy. Elsewhere in the Kilosort4 code, we would refer
to these as 'clusters.'
amplitudes.npy : shape (n_spikes,)
Per-spike amplitudes, computed as the L2 norm of the PC features
for each spike.
channel_map.npy : shape (n_channels,)
Same as probe['chanMap']. Integer indices into rows of binary file
that map the data to the contacts listed in the probe file.
channel_positions.npy : shape (n_channels,2)
Same as probe['xc'] and probe['yc'], but combined in a single array.
Indicates x- and y- positions (in microns) of probe contacts.
cluster_Amplitude.tsv : shape (n_templates,)
Per-template amplitudes, computed as the L2 norm of the template.
cluster_ContamPct.tsv : shape (n_templates,)
Contamination rate for each template, computed as fraction of refractory
period violations relative to expectation based on a Poisson process.
cluster_KSLabel.tsv : shape (n_templates,)
Label indicating whether each template is 'mua' (multi-unit activity)
or 'good' (refractory).
cluster_group.tsv : shape (n_templates,)
Same as `cluster_KSLabel.tsv`.
kept_spikes.npy : shape (n_spikes,)
Boolean mask that is False for spikes that were removed by
`kilosort.postprocessing.remove_duplicate_spikes` and True otherwise.
ops.npy : shape N/A
Dictionary containing a number of state variables saved throughout
the sorting process (see `run_kilosort`). We recommend loading with
`kilosort.io.load_ops`.
params.py : shape N/A
Settings used by Phy, like data location and sampling rate.
similar_templates.npy : shape (n_templates, n_templates)
Similarity score between each pair of templates, computed as correlation
between templates.
spike_clusters.npy : shape (n_spikes,)
For each spike, integer indicating which template it was assigned to.
spike_templates.npy : shape (n_spikes,)
Same as `spike_clusters.npy`.
spike_positions.npy : shape (n_spikes,2)
Estimated (x,y) position relative to probe geometry, in microns,
for each spike.
spike_times.npy : shape (n_spikes,)
Sample index of the waveform peak for each spike.
templates.npy : shape (n_templates, nt, n_channels)
Full time x channels template shapes.
templates_ind.npy : shape (n_templates, n_channels)
Channel indices on which each cluster is defined. For KS4, this is always
all channels, but Phy requires this file.
whitening_mat.npy : shape (n_channels, n_channels)
Matrix applied to data for whitening.
whitening_mat_inv.npy : shape (n_channels, n_channels)
Inverse of whitening matrix.
whitening_mat_dat.npy : shape (n_channels, n_channels)
Matrix applied to data for whitening. Currently this is the same as
`whitening_mat.npy`, but was added because the latter was previously
altered before saving for Phy, so this ensured the original was still
saved. It's kept in for now because we may need to change the version
used by Phy again in the future.
"""

if results_dir is None:
results_dir = ops['data_dir'].joinpath('kilosort4')
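
The Notes above describe `amplitudes.npy` as the L2 norm of each spike's PC features. A minimal sketch of that calculation follows, assuming `tF` has already been converted to a NumPy array with the documented shape `(n_spikes, nearest_chans, n_pcs)`. The helper name `spike_amplitudes` and the random data are hypothetical, and this is not necessarily how Kilosort4 computes the saved values.

```python
import numpy as np

# Synthetic stand-in for tF with shape (n_spikes, nearest_chans, n_pcs).
# In real use tF is a torch.Tensor; convert it first, e.g. tF.cpu().numpy().
rng = np.random.default_rng(0)
n_spikes, nearest_chans, n_pcs = 5, 10, 6
tF = rng.standard_normal((n_spikes, nearest_chans, n_pcs)).astype(np.float32)


def spike_amplitudes(features):
    """Hypothetical helper: L2 norm over channel and PC axes, one value per spike."""
    flat = features.reshape(features.shape[0], -1)
    return np.linalg.norm(flat, axis=1)


amplitudes = spike_amplitudes(tF)
print(amplitudes.shape)  # (5,)
```

The result has one entry per spike, matching the documented shape `(n_spikes,)` of `amplitudes.npy`.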
127 changes: 107 additions & 20 deletions kilosort/run_kilosort.py
@@ -99,8 +99,38 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
Returns
-------
- ops, st, clu, tF, Wall, similar_templates, is_ref, est_contam_rate, kept_spikes
- Description TODO
ops : dict
Dictionary storing settings and results for all algorithmic steps.
st : np.ndarray
3-column array of peak time (in samples), template, and amplitude for
each spike.
clu : np.ndarray
1D vector of cluster ids indicating which spike came from which cluster,
same shape as `st[:,0]`.
tF : torch.Tensor
PC features for each spike, with shape
(n_spikes, nearest_chans, n_pcs)
Wall : torch.Tensor
PC feature representation of spike waveforms for each cluster, with shape
(n_clusters, n_channels, n_pcs).
similar_templates : np.ndarray.
Similarity score between each pair of clusters, computed as correlation
between clusters. Shape (n_clusters, n_clusters).
is_ref : np.ndarray.
1D boolean array with shape (n_clusters,) indicating whether each
cluster is refractory.
est_contam_rate : np.ndarray.
Contamination rate for each cluster, computed as fraction of refractory
period violations relative to expectation based on a Poisson process.
Shape (n_clusters,).
kept_spikes : np.ndarray.
Boolean mask with shape (n_spikes,) that is False for spikes that were
removed by `kilosort.postprocessing.remove_duplicate_spikes`
and True otherwise.
Notes
-----
For documentation of saved files, see `kilosort.io.save_to_phy`.
"""
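
To illustrate how the documented return values relate to one another, here is a small sketch using toy arrays with the shapes given in the docstring. It assumes `st` and `clu` are still indexed over all `n_spikes`, as the documented shape of `kept_spikes` suggests. The sampling rate `fs` and all values are made up for illustration and are not taken from the source.

```python
import numpy as np

# Hypothetical toy outputs with the documented shapes.
st = np.array([[100, 0, 12.5],   # columns: peak time (samples), template, amplitude
               [250, 1, 9.0],
               [251, 1, 8.5],
               [900, 0, 11.0]])
clu = np.array([0, 1, 1, 0])                        # cluster id per spike
kept_spikes = np.array([True, True, False, True])   # False marks a removed duplicate

# All three are indexed by spike, so their leading dimensions must agree.
assert clu.shape == st[:, 0].shape == kept_spikes.shape

fs = 30000  # assumed sampling rate in Hz (illustrative only)
spike_samples = st[kept_spikes, 0].astype(np.int64)
spike_seconds = spike_samples / fs
kept_clu = clu[kept_spikes]

# Count the remaining spikes per cluster.
cluster_ids, counts = np.unique(kept_clu, return_counts=True)
for cid, n in zip(cluster_ids, counts):
    print(f"cluster {cid}: {n} spikes")
```

Applying the boolean mask to both `st` and `clu` keeps the two arrays aligned, so any per-cluster statistic can be computed without the removed duplicates.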

@@ -447,8 +477,12 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,
Returns
-------
ops : dict
Dictionary storing settings and results for all algorithmic steps.
bfile : kilosort.io.BinaryFiltered
Wrapped file object for handling data.
st0 : np.ndarray.
Intermediate spike times variable with 6 columns. This is only used
for generating the 'Drift Scatter' plot through the GUI.
"""

@@ -493,7 +527,7 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,


def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None):
- """Run spike sorting algorithm and save intermediate results to `ops`.
"""Detect spikes via template deconvolution.
Parameters
----------
@@ -511,14 +545,17 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None):
Returns
-------
st : np.ndarray
- 1D vector of spike times for all clusters.
3-column array of peak time (in samples), template, and amplitude for
each spike.
clu : np.ndarray
1D vector of cluster ids indicating which spike came from which cluster,
same shape as `st[:,0]`.
- tF : np.ndarray
- TODO
- Wall : np.ndarray
- TODO
tF : torch.Tensor
PC features for each spike, with shape
(n_spikes, nearest_chans, n_pcs)
Wall : torch.Tensor
PC feature representation of spike waveforms for each cluster, with shape
(n_clusters, n_channels, n_pcs).
"""

@@ -564,6 +601,37 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None):


def cluster_spikes(st, tF, ops, device, bfile, tic0=np.nan, progress_bar=None):
"""Cluster spikes using graph-based methods.
Parameters
----------
st : np.ndarray
3-column array of peak time (in samples), template, and amplitude for
each spike.
tF : torch.Tensor
PC features for each spike, with shape
(n_spikes, nearest_chans, n_pcs)
ops : dict
Dictionary storing settings and results for all algorithmic steps.
device : torch.device
Indicates whether `pytorch` operations should be run on cpu or gpu.
bfile : kilosort.io.BinaryFiltered
Wrapped file object for handling data.
tic0 : float; default=np.nan.
Start time of `run_kilosort`.
progress_bar : TODO; optional.
Informs `tqdm` package how to report progress, type unclear.
Returns
-------
clu : np.ndarray
1D vector of cluster ids indicating which spike came from which cluster,
same shape as `st[:,0]`.
Wall : torch.Tensor
PC feature representation of spike waveforms for each cluster, with shape
(n_clusters, n_channels, n_pcs).
"""
tic = time.time()
logger.info(' ')
logger.info('Final clustering')
@@ -603,21 +671,25 @@ def save_sorting(ops, results_dir, st, clu, tF, Wall, imin, tic0=np.nan,
results_dir : pathlib.Path
Directory where results should be saved.
st : np.ndarray
- 1D vector of spike times for all clusters.
3-column array of peak time (in samples), template, and amplitude for
each spike.
clu : np.ndarray
1D vector of cluster ids indicating which spike came from which cluster,
- same shape as `st`.
- tF : np.ndarray
- TODO
- Wall : np.ndarray
- TODO
same shape as `st[:,0]`.
tF : torch.Tensor
PC features for each spike, with shape
(n_spikes, nearest_chans, n_pcs)
Wall : torch.Tensor
PC feature representation of spike waveforms for each cluster, with shape
(n_clusters, n_channels, n_pcs).
imin : int
Minimum sample index used by BinaryRWFile; exported spike times will
be shifted forward by this number.
tic0 : float; default=np.nan.
Start time of `run_kilosort`.
save_extra_vars : bool; default=False.
- If True, save tF and Wall to disk after sorting.
If True, save tF and Wall to disk along with copies of st, clu and
amplitudes with no postprocessing applied.
save_preprocessed_copy : bool; default=False.
If True, save a pre-processed copy of the data (including drift
correction) to `temp_wh.dat` in the results directory and format Phy
@@ -626,11 +698,26 @@ def save_sorting(ops, results_dir, st, clu, tF, Wall, imin, tic0=np.nan,
Returns
-------
ops : dict
- similar_templates : np.ndarray
- is_ref : np.ndarray
- est_contam_rate : np.ndarray
- kept_spikes : np.ndarray
Dictionary storing settings and results for all algorithmic steps.
similar_templates : np.ndarray.
Similarity score between each pair of clusters, computed as correlation
between clusters. Shape (n_clusters, n_clusters).
is_ref : np.ndarray.
1D boolean array with shape (n_clusters,) indicating whether each
cluster is refractory.
est_contam_rate : np.ndarray.
Contamination rate for each cluster, computed as fraction of refractory
period violations relative to expectation based on a Poisson process.
Shape (n_clusters,).
kept_spikes : np.ndarray.
Boolean mask with shape (n_spikes,) that is False for spikes that were
removed by `kilosort.postprocessing.remove_duplicate_spikes`
and True otherwise.
Notes
-----
For documentation of saved files, see `kilosort.io.save_to_phy`.
"""

logger.info(' ')
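
The saved-file documentation referenced above lists `cluster_KSLabel.tsv` as holding a 'good' or 'mua' label per template. Below is a minimal sketch of reading such a file with the standard library. It assumes a two-column, tab-separated layout with a header row; the header names `cluster_id` and `KSLabel`, the helper `read_cluster_labels`, and the sample rows are assumptions for illustration, not details confirmed by this commit.

```python
import csv
import tempfile
from pathlib import Path


def read_cluster_labels(tsv_path):
    """Hypothetical helper: return {cluster_id: label} from a two-column TSV."""
    labels = {}
    with open(tsv_path, newline="") as f:
        reader = csv.reader(f, delimiter="\t")
        next(reader)  # skip the header row (assumed to be present)
        for row in reader:
            if row:  # ignore blank lines
                labels[int(row[0])] = row[1]
    return labels


# Write a small stand-in file so the example runs without real sorting output.
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "cluster_KSLabel.tsv"
    path.write_text("cluster_id\tKSLabel\n0\tgood\n1\tmua\n2\tgood\n")
    labels = read_cluster_labels(path)

good_ids = sorted(cid for cid, lab in labels.items() if lab == "good")
print(good_ids)  # [0, 2]
```

With real output, `path` would point into the `results_dir` returned by the sorting functions instead of a temporary directory.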