Skip to content

Commit

Permalink
MRG: Bump colorbar control points (mne-tools#7188)
Browse files Browse the repository at this point in the history
* Add SnapColorbarPoints helper class

* Use BumpColorbarPoints helper class

* Update tests

* Rollback event_type to default

* Set new range for colorbar sliders

* ENH: Better colorbar logic

* Improve coverage

* Improve coverage

* Tweak the callback refresh rate

Co-authored-by: Eric Larson <larson.eric.d@gmail.com>
  • Loading branch information
2 people authored and AdoNunes committed Apr 6, 2020
1 parent 2653d4c commit 25bf5db
Show file tree
Hide file tree
Showing 4 changed files with 295 additions and 112 deletions.
115 changes: 58 additions & 57 deletions mne/viz/_brain/_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,14 +793,18 @@ def update_lut(self, fmin=None, fmid=None, fmax=None):
center = self._data['center']
colormap = self._data['colormap']
transparent = self._data['transparent']
fmin = self._data['fmin'] if fmin is None else fmin
fmid = self._data['fmid'] if fmid is None else fmid
fmax = self._data['fmax'] if fmax is None else fmax

lims = dict(fmin=fmin, fmid=fmid, fmax=fmax)
lims = {key: self._data[key] if val is None else val
for key, val in lims.items()}
assert all(val is not None for val in lims.values())
if lims['fmin'] > lims['fmid']:
lims['fmin'] = lims['fmid']
if lims['fmax'] < lims['fmid']:
lims['fmax'] = lims['fmid']
self._data.update(lims)
self._data['ctable'] = \
calculate_lut(colormap, alpha=alpha, fmin=fmin, fmid=fmid,
fmax=fmax, center=center, transparent=transparent)

calculate_lut(colormap, alpha=alpha, center=center,
transparent=transparent, **lims)
return self._data['ctable']

def set_data_smoothing(self, n_steps):
Expand Down Expand Up @@ -863,64 +867,61 @@ def set_time_point(self, time_idx):
def update_fmax(self, fmax):
"""Set the colorbar max point."""
from ..backends._pyvista import _set_colormap_range
if fmax > self._data['fmid']:
ctable = self.update_lut(fmax=fmax)
ctable = (ctable * 255).astype(np.uint8)
center = self._data['center']
for hemi in ['lh', 'rh']:
actor = self._data.get(hemi + '_actor')
if actor is not None:
fmin = self._data['fmin']
center = self._data['center']
dt_max = fmax
dt_min = fmin if center is None else -1 * fmax
rng = [dt_min, dt_max]
if self._colorbar_added:
scalar_bar = self._renderer.plotter.scalar_bar
else:
scalar_bar = None
_set_colormap_range(actor, ctable, scalar_bar, rng)
self._data['fmax'] = fmax
self._data['ctable'] = ctable
ctable = self.update_lut(fmax=fmax)
ctable = (ctable * 255).astype(np.uint8)
center = self._data['center']
for hemi in ['lh', 'rh']:
actor = self._data.get(hemi + '_actor')
if actor is not None:
fmin = self._data['fmin']
center = self._data['center']
dt_max = fmax
dt_min = fmin if center is None else -1 * fmax
rng = [dt_min, dt_max]
if self._colorbar_added:
scalar_bar = self._renderer.plotter.scalar_bar
else:
scalar_bar = None
_set_colormap_range(actor, ctable, scalar_bar, rng)
self._data['fmax'] = fmax
self._data['ctable'] = ctable

def update_fmid(self, fmid):
"""Set the colorbar mid point."""
from ..backends._pyvista import _set_colormap_range
if self._data['fmin'] < fmid < self._data['fmax']:
ctable = self.update_lut(fmid=fmid)
ctable = (ctable * 255).astype(np.uint8)
for hemi in ['lh', 'rh']:
actor = self._data.get(hemi + '_actor')
if actor is not None:
if self._colorbar_added:
scalar_bar = self._renderer.plotter.scalar_bar
else:
scalar_bar = None
_set_colormap_range(actor, ctable, scalar_bar)
self._data['fmid'] = fmid
self._data['ctable'] = ctable
ctable = self.update_lut(fmid=fmid)
ctable = (ctable * 255).astype(np.uint8)
for hemi in ['lh', 'rh']:
actor = self._data.get(hemi + '_actor')
if actor is not None:
if self._colorbar_added:
scalar_bar = self._renderer.plotter.scalar_bar
else:
scalar_bar = None
_set_colormap_range(actor, ctable, scalar_bar)
self._data['fmid'] = fmid
self._data['ctable'] = ctable

def update_fmin(self, fmin):
"""Set the colorbar min point."""
from ..backends._pyvista import _set_colormap_range
if fmin < self._data['fmid']:
ctable = self.update_lut(fmin=fmin)
ctable = (ctable * 255).astype(np.uint8)
for hemi in ['lh', 'rh']:
actor = self._data.get(hemi + '_actor')
if actor is not None:
fmax = self._data['fmax']
center = self._data['center']
dt_max = fmax
dt_min = fmin if center is None else -1 * fmax
rng = [dt_min, dt_max]
if self._colorbar_added:
scalar_bar = self._renderer.plotter.scalar_bar
else:
scalar_bar = None
_set_colormap_range(actor, ctable, scalar_bar, rng)
self._data['fmin'] = fmin
self._data['ctable'] = ctable
ctable = self.update_lut(fmin=fmin)
ctable = (ctable * 255).astype(np.uint8)
for hemi in ['lh', 'rh']:
actor = self._data.get(hemi + '_actor')
if actor is not None:
fmax = self._data['fmax']
center = self._data['center']
dt_max = fmax
dt_min = fmin if center is None else -1 * fmax
rng = [dt_min, dt_max]
if self._colorbar_added:
scalar_bar = self._renderer.plotter.scalar_bar
else:
scalar_bar = None
_set_colormap_range(actor, ctable, scalar_bar, rng)
self._data['fmin'] = fmin
self._data['ctable'] = ctable

def update_fscale(self, fscale):
"""Scale the colorbar points."""
Expand Down
106 changes: 90 additions & 16 deletions mne/viz/_brain/_timeviewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
#
# License: Simplified BSD

import time
import numpy as np


class IntSlider(object):
"""Class to set a integer slider."""
Expand All @@ -13,9 +16,9 @@ def __init__(self, plotter=None, callback=None, name=None):
self.callback = callback
self.name = name

def __call__(self, idx):
def __call__(self, value):
"""Round the label of the slider."""
idx = int(round(idx))
idx = int(round(value))
for slider in self.plotter.slider_widgets:
name = getattr(slider, "name", None)
if name == self.name:
Expand Down Expand Up @@ -50,6 +53,58 @@ def __call__(self, value):
slider_rep.SetValue(fmax)


class BumpColorbarPoints(object):
"""Class that ensure constraints over the colorbar points."""

def __init__(self, plotter=None, brain=None, name=None):
self.plotter = plotter
self.brain = brain
self.name = name
self.callback = {
"fmin": brain.update_fmin,
"fmid": brain.update_fmid,
"fmax": brain.update_fmax
}
self.last_update = time.time()

def __call__(self, value):
"""Update the colorbar sliders."""
keys = ('fmin', 'fmid', 'fmax')
vals = {key: self.brain._data[key] for key in keys}
reps = {key: None for key in keys}
for slider in self.plotter.slider_widgets:
name = getattr(slider, "name", None)
if name is not None:
reps[name] = slider.GetRepresentation()
if self.name == "fmin" and reps["fmin"] is not None:
if vals['fmax'] < value:
self.brain.update_fmax(value)
reps['fmax'].SetValue(value)
if vals['fmid'] < value:
self.brain.update_fmid(value)
reps['fmid'].SetValue(value)
reps['fmin'].SetValue(value)
elif self.name == "fmid" and reps['fmid'] is not None:
if vals['fmin'] > value:
self.brain.update_fmin(value)
reps['fmin'].SetValue(value)
if vals['fmax'] < value:
self.brain.update_fmax(value)
reps['fmax'].SetValue(value)
reps['fmid'].SetValue(value)
elif self.name == "fmax" and reps['fmax'] is not None:
if vals['fmin'] > value:
self.brain.update_fmin(value)
reps['fmin'].SetValue(value)
if vals['fmid'] > value:
self.brain.update_fmid(value)
reps['fmid'].SetValue(value)
reps['fmax'].SetValue(value)
if time.time() > self.last_update + 1. / 60.:
self.callback[self.name](value)
self.last_update = time.time()


class _TimeViewer(object):
"""Class to interact with _Brain."""

Expand All @@ -66,20 +121,20 @@ def __init__(self, brain):

# smoothing slider
default_smoothing_value = 7
set_smoothing = IntSlider(
self.set_smoothing = IntSlider(
plotter=self.plotter,
callback=brain.set_data_smoothing,
name="smoothing"
)
smoothing_slider = self.plotter.add_slider_widget(
set_smoothing,
self.set_smoothing,
value=default_smoothing_value,
rng=[0, 15], title="smoothing",
pointa=(0.82, 0.90),
pointb=(0.98, 0.90)
)
smoothing_slider.name = 'smoothing'
set_smoothing(default_smoothing_value)
self.set_smoothing(default_smoothing_value)

# orientation slider
orientation = [
Expand Down Expand Up @@ -122,30 +177,48 @@ def __init__(self, brain):
# colormap slider
scaling_limits = [0.2, 2.0]
fmin = brain._data["fmin"]
self.update_fmin = BumpColorbarPoints(
plotter=self.plotter,
brain=brain,
name="fmin"
)
fmin_slider = self.plotter.add_slider_widget(
brain.update_fmin,
self.update_fmin,
value=fmin,
rng=_get_range(fmin, scaling_limits), title="fmin",
rng=_get_range(brain), title="fmin",
pointa=(0.82, 0.26),
pointb=(0.98, 0.26)
pointb=(0.98, 0.26),
event_type="always",
)
fmin_slider.name = "fmin"
fmid = brain._data["fmid"]
self.update_fmid = BumpColorbarPoints(
plotter=self.plotter,
brain=brain,
name="fmid",
)
fmid_slider = self.plotter.add_slider_widget(
brain.update_fmid,
self.update_fmid,
value=fmid,
rng=_get_range(fmid, scaling_limits), title="fmid",
rng=_get_range(brain), title="fmid",
pointa=(0.82, 0.42),
pointb=(0.98, 0.42)
pointb=(0.98, 0.42),
event_type="always",
)
fmid_slider.name = "fmid"
fmax = brain._data["fmax"]
self.update_fmax = BumpColorbarPoints(
plotter=self.plotter,
brain=brain,
name="fmax",
)
fmax_slider = self.plotter.add_slider_widget(
brain.update_fmax,
self.update_fmax,
value=fmax,
rng=_get_range(fmax, scaling_limits), title="fmax",
rng=_get_range(brain), title="fmax",
pointa=(0.82, 0.58),
pointb=(0.98, 0.58)
pointb=(0.98, 0.58),
event_type="always",
)
fmax_slider.name = "fmax"
update_fscale = UpdateColorbarScale(
Expand Down Expand Up @@ -194,5 +267,6 @@ def _set_slider_style(slider, show_label=True):
slider_rep.ShowSliderLabelOff()


def _get_range(val, rng):
return [val * rng[0], val * rng[1]]
def _get_range(brain):
val = np.abs(brain._data['array'])
return [np.min(val), np.max(val)]
Loading

0 comments on commit 25bf5db

Please sign in to comment.