Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor quality metrics tests to use fixture #3249

Merged
merged 10 commits into from
Sep 12, 2024
51 changes: 51 additions & 0 deletions src/spikeinterface/qualitymetrics/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
create_sorting_analyzer,
)

# Shared parallel-processing settings for all extension computations in this module.
job_kwargs = dict(n_jobs=-1, progress_bar=True, chunk_duration="1s")


def _small_sorting_analyzer():
recording, sorting = generate_ground_truth_recording(
Expand Down Expand Up @@ -35,3 +37,52 @@ def _small_sorting_analyzer():
@pytest.fixture(scope="module")
def small_sorting_analyzer():
    """Module-scoped fixture: build the small analyzer once and share it across the test module."""
    return _small_sorting_analyzer()


def _sorting_analyzer_simple():
    """Build an in-memory, sparse SortingAnalyzer on a 120 s ground-truth recording.

    Uses a high firing rate (10 Hz) so that amplitude_cutoff has enough spikes
    to estimate from, and computes the extensions the quality-metric tests
    depend on: random_spikes, noise_levels, waveforms, templates and
    spike_amplitudes.
    """
    # we need high firing rate for amplitude_cutoff
    recording, sorting = generate_ground_truth_recording(
        durations=[
            120.0,
        ],
        sampling_frequency=30_000.0,
        num_channels=6,
        num_units=10,
        generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0),
        generate_unit_locations_kwargs=dict(
            margin_um=5.0,
            minimum_z=5.0,
            maximum_z=20.0,
        ),
        generate_templates_kwargs=dict(
            unit_params=dict(
                alpha=(200.0, 500.0),
            )
        ),
        noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
        seed=1205,
    )

    sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True)

    # Fixed seed keeps the random spike selection (and thus metric values) reproducible.
    sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=1205)
    sorting_analyzer.compute("noise_levels")
    sorting_analyzer.compute("waveforms", **job_kwargs)
    sorting_analyzer.compute("templates")
    sorting_analyzer.compute("spike_amplitudes", **job_kwargs)

    return sorting_analyzer


@pytest.fixture(scope="module")
def sorting_analyzer_simple():
    """Module-scoped fixture: build `_sorting_analyzer_simple()` once per test module.

    The expensive recording generation and extension computation happen a
    single time and the resulting analyzer is shared by all tests that
    request this fixture.
    """
    return _sorting_analyzer_simple()
85 changes: 32 additions & 53 deletions src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,36 +135,6 @@ def test_unit_id_order_independence(small_sorting_analyzer):
assert quality_metrics_2[metric][1] == metric_1_data["#4"]


def _sorting_analyzer_simple():
    """Build an in-memory, sparse SortingAnalyzer on a 50 s ground-truth recording.

    Computes the extensions used by the metric tests: random_spikes,
    noise_levels, waveforms, templates, principal_components and
    spike_amplitudes.
    """
    recording, sorting = generate_ground_truth_recording(
        durations=[
            50.0,
        ],
        sampling_frequency=30_000.0,
        num_channels=6,
        num_units=10,
        generate_sorting_kwargs=dict(firing_rates=6.0, refractory_period_ms=4.0),
        noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
        seed=2205,
    )

    sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True)

    # Same seed as the recording generation keeps the spike subsample reproducible.
    sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=2205)
    sorting_analyzer.compute("noise_levels")
    sorting_analyzer.compute("waveforms", **job_kwargs)
    sorting_analyzer.compute("templates")
    sorting_analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs)
    sorting_analyzer.compute("spike_amplitudes", **job_kwargs)

    return sorting_analyzer


@pytest.fixture(scope="module")
def sorting_analyzer_simple():
    """Module-scoped fixture: build `_sorting_analyzer_simple()` once and share it across this module."""
    return _sorting_analyzer_simple()


def _sorting_violation():
max_time = 100.0
sampling_frequency = 30000
Expand Down Expand Up @@ -570,27 +540,36 @@ def test_calculate_sd_ratio(sorting_analyzer_simple):

if __name__ == "__main__":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree — I think this is essential for certain Windows runs. But it is super inconsistent which Windows setups need it (for example, I haven't needed to do `if name == main` yet, but others have...). So weird.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no way! 🤯 so it depends on the Windows version?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is less "version" and more like version + sub-version + python version + CPU-type + IDE. Who knows what magic determines this.... But one of my co-workers has to do the if name == main, but I don't on my workstation.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had responded but it somehow disappeared: "Haha lol, the mysteries of multiprocessing — very interesting!"


sorting_analyzer = _sorting_analyzer_simple()
print(sorting_analyzer)

test_unit_structure_in_output(_small_sorting_analyzer())

# test_calculate_firing_rate_num_spikes(sorting_analyzer)
chrishalcrow marked this conversation as resolved.
Show resolved Hide resolved
# test_calculate_snrs(sorting_analyzer)
# test_calculate_amplitude_cutoff(sorting_analyzer)
# test_calculate_presence_ratio(sorting_analyzer)
# test_calculate_amplitude_median(sorting_analyzer)
# test_calculate_sliding_rp_violations(sorting_analyzer)
# test_calculate_drift_metrics(sorting_analyzer)
# test_synchrony_metrics(sorting_analyzer)
# test_synchrony_metrics_unit_id_subset(sorting_analyzer)
# test_synchrony_metrics_no_unit_ids(sorting_analyzer)
# test_calculate_firing_range(sorting_analyzer)
# test_calculate_amplitude_cv_metrics(sorting_analyzer)
# test_calculate_sd_ratio(sorting_analyzer)

# sorting_analyzer_violations = _sorting_analyzer_violations()
test_unit_structure_in_output(small_sorting_analyzer)
test_unit_id_order_independence(small_sorting_analyzer)

test_synchrony_counts_no_sync()
test_synchrony_counts_one_sync()
test_synchrony_counts_one_quad_sync()
test_synchrony_counts_not_all_units()

test_mahalanobis_metrics()
test_lda_metrics()
test_nearest_neighbors_metrics()
test_silhouette_score_metrics()
test_simplified_silhouette_score_metrics()

test_calculate_firing_rate_num_spikes(sorting_analyzer_simple)
test_calculate_snrs(sorting_analyzer)
test_calculate_amplitude_cutoff(sorting_analyzer)
test_calculate_presence_ratio(sorting_analyzer)
test_calculate_amplitude_median(sorting_analyzer)
test_calculate_sliding_rp_violations(sorting_analyzer)
test_calculate_drift_metrics(sorting_analyzer)
test_synchrony_metrics(sorting_analyzer)
test_synchrony_metrics_unit_id_subset(sorting_analyzer)
test_synchrony_metrics_no_unit_ids(sorting_analyzer)
test_calculate_firing_range(sorting_analyzer)
test_calculate_amplitude_cv_metrics(sorting_analyzer)
test_calculate_sd_ratio(sorting_analyzer)

sorting_analyzer_violations = _sorting_analyzer_violations()
# print(sorting_analyzer_violations)
# test_calculate_isi_violations(sorting_analyzer_violations)
# test_calculate_sliding_rp_violations(sorting_analyzer_violations)
# test_calculate_rp_violations(sorting_analyzer_violations)
test_calculate_isi_violations(sorting_analyzer_violations)
test_calculate_sliding_rp_violations(sorting_analyzer_violations)
test_calculate_rp_violations(sorting_analyzer_violations)
5 changes: 5 additions & 0 deletions src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,8 @@ def test_calculate_pc_metrics(small_sorting_analyzer):
assert not np.all(np.isnan(res2[metric_name].values))

assert np.array_equal(res1[metric_name].values, res2[metric_name].values)


if __name__ == "__main__":

    # NOTE(review): `small_sorting_analyzer` here is the pytest fixture *function*,
    # not a built analyzer instance — running this file as a plain script passes
    # the function object itself; confirm a constructed analyzer is intended.
    test_calculate_pc_metrics(small_sorting_analyzer)
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from pathlib import Path
import numpy as np


from spikeinterface.core import (
generate_ground_truth_recording,
create_sorting_analyzer,
Expand All @@ -15,51 +14,9 @@
compute_quality_metrics,
)


job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s")


def get_sorting_analyzer(seed=2205):
    """Build an in-memory, sparse SortingAnalyzer on a 120 s ground-truth recording.

    Parameters
    ----------
    seed : int, default: 2205
        Seed forwarded to both the recording generation and the random spike
        selection, making the analyzer fully reproducible.

    Returns
    -------
    sorting_analyzer
        Analyzer with random_spikes, noise_levels, waveforms, templates and
        spike_amplitudes computed.
    """
    # we need high firing rate for amplitude_cutoff
    recording, sorting = generate_ground_truth_recording(
        durations=[
            120.0,
        ],
        sampling_frequency=30_000.0,
        num_channels=6,
        num_units=10,
        generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0),
        generate_unit_locations_kwargs=dict(
            margin_um=5.0,
            minimum_z=5.0,
            maximum_z=20.0,
        ),
        generate_templates_kwargs=dict(
            unit_params=dict(
                alpha=(200.0, 500.0),
            )
        ),
        noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
        seed=seed,
    )

    sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True)

    sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=seed)
    sorting_analyzer.compute("noise_levels")
    sorting_analyzer.compute("waveforms", **job_kwargs)
    sorting_analyzer.compute("templates")
    sorting_analyzer.compute("spike_amplitudes", **job_kwargs)

    return sorting_analyzer


@pytest.fixture(scope="module")
def sorting_analyzer_simple():
    """Module-scoped fixture: one shared analyzer (fixed seed) for every test in this module."""
    sorting_analyzer = get_sorting_analyzer(seed=2205)
    return sorting_analyzer


def test_compute_quality_metrics(sorting_analyzer_simple):
sorting_analyzer = sorting_analyzer_simple
print(sorting_analyzer)
Expand Down Expand Up @@ -291,9 +248,6 @@ def test_empty_units(sorting_analyzer_simple):

if __name__ == "__main__":

    # NOTE(review): `sorting_analyzer_simple` is the pytest fixture *function* —
    # running this file as a plain script passes the function object rather than
    # a built analyzer; confirm the intended direct-run entry point.
    test_compute_quality_metrics(sorting_analyzer_simple)
    test_compute_quality_metrics_recordingless(sorting_analyzer_simple)
    test_empty_units(sorting_analyzer_simple)