MRG, ENH: Speed up epochs.copy #7968

Merged
merged 3 commits on Jul 7, 2020

Changes from 2 commits
4 changes: 4 additions & 0 deletions doc/changes/latest.inc
@@ -69,6 +69,8 @@ Changelog

- Speed up raw data reading without preload in :func:`mne.io.read_raw_nirx` by `Eric Larson`_

- Speed up :meth:`mne.Epochs.copy` and :meth:`mne.Epochs.__getitem__` by avoiding copying immutable attributes by `Eric Larson`_

- Support for saving movies of source time courses (STCs) with ``brain.save_movie`` method and from graphical user interface by `Guillaume Favelier`_

- Add ``mri`` and ``show_orientation`` arguments to :func:`mne.viz.plot_bem` by `Eric Larson`_
@@ -174,6 +176,8 @@ Bug

- Fix bug with :class:`mne.Epochs` when metadata was not subselected properly when ``event_repeated='drop'`` by `Eric Larson`_

- Fix bug with :class:`mne.Epochs` where ``epochs.drop_log`` was a list of list of str rather than an immutable tuple of tuple of str (not meant to be changed by the user) by `Eric Larson`_

- Fix bug with :class:`mne.Report` where the BEM section could not be toggled by `Eric Larson`_

- Fix bug where using :meth:`mne.Epochs.crop` to exclude the baseline period would break the :meth:`mne.Epochs.save` / :func:`mne.read_epochs` round-trip by `Eric Larson`_
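The changelog entry above is the heart of this PR: ``Epochs.copy`` used to deep-copy every attribute, including a potentially very long ``drop_log``. Since ``drop_log`` is now an immutable tuple of tuples, it can safely be shared between copies. A rough way to check the effect (a sketch, not from the PR; the sample-dataset path and any timings are illustrative):

```python
import os.path as op
import timeit

import mne

# Assumes the MNE sample dataset is available (downloaded on first use)
data_path = mne.datasets.sample.data_path()
raw_fname = op.join(data_path, 'MEG', 'sample', 'sample_audvis_raw.fif')
raw = mne.io.read_raw_fif(raw_fname)
events = mne.find_events(raw)
epochs = mne.Epochs(raw, events)

# copy() no longer deep-copies drop_log (or _raw), so this should be
# noticeably faster than before the change
print(timeit.timeit(epochs.copy, number=100))
```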
121 changes: 69 additions & 52 deletions mne/epochs.py
@@ -17,10 +17,8 @@
import operator
import os.path as op
import warnings
from distutils.version import LooseVersion

import numpy as np
import scipy

from .io.write import (start_file, start_block, end_file, end_block,
write_int, write_float, write_float_matrix,
@@ -271,6 +269,7 @@ def _handle_event_repeated(events, event_id, event_repeated, selection,

# Else, we have duplicates. Triage ...
_check_option('event_repeated', event_repeated, ['error', 'drop', 'merge'])
drop_log = list(drop_log)
if event_repeated == 'error':
raise RuntimeError('Event time samples were not unique. Consider '
'setting the `event_repeated` parameter.')
@@ -282,7 +281,7 @@ def _handle_event_repeated(events, event_id, event_repeated, selection,
new_selection = selection[u_ev_idxs]
drop_ev_idxs = np.setdiff1d(selection, new_selection)
for idx in drop_ev_idxs:
drop_log[idx].append('DROP DUPLICATE')
drop_log[idx] = drop_log[idx] + ('DROP DUPLICATE',)
selection = new_selection
elif event_repeated == 'merge':
logger.info('Multiple event values for single event times found. '
@@ -291,8 +290,9 @@ def _handle_event_repeated(events, event_id, event_repeated, selection,
_merge_events(events, event_id, selection)
drop_ev_idxs = np.setdiff1d(selection, new_selection)
for idx in drop_ev_idxs:
drop_log[idx].append('MERGE DUPLICATE')
drop_log[idx] = drop_log[idx] + ('MERGE DUPLICATE',)
selection = new_selection
drop_log = tuple(drop_log)

# Remove obsolete kv-pairs from event_id after handling
keys = new_events[:, 1:].flatten()
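Because ``drop_log`` is immutable now, the function briefly converts it to a list, rebuilds the inner tuples it needs to change, and freezes it again. The same pattern in isolation (a standalone sketch, not MNE code):

```python
# Convert-mutate-refreeze pattern for an immutable tuple-of-tuples log
drop_log = ((), (), ('IGNORED',))

editable = list(drop_log)                        # outer tuple -> list
editable[0] = editable[0] + ('DROP DUPLICATE',)  # rebuild the inner tuple
drop_log = tuple(editable)                       # freeze again

assert drop_log == (('DROP DUPLICATE',), (), ('IGNORED',))
```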
@@ -355,7 +355,7 @@ class BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin, ShiftTimeMixin,
Iterable of indices of selected epochs. If ``None``, will be
automatically generated, corresponding to all non-zero events.
drop_log : list | None
List of lists of strings indicating which epochs have been marked to be
List of tuples of strings indicating which epochs have been marked to be
ignored.
filename : str | None
The filename (if the epochs are read from disk).
@@ -432,9 +432,10 @@ def __init__(self, info, data, events, event_id=None, tmin=-0.2, tmax=0.5,
% (selected.shape, selection.shape))
self.selection = selection
if drop_log is None:
self.drop_log = [list() if k in self.selection else ['IGNORED']
for k in range(max(len(self.events),
max(self.selection) + 1))]
self.drop_log = tuple(
() if k in self.selection else ('IGNORED',)
for k in range(max(len(self.events),
max(self.selection) + 1)))
else:
self.drop_log = drop_log

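For illustration, here is what the new default initializer yields for a small case (a standalone mirror of the generator above):

```python
# Entries outside `selection` are marked ('IGNORED',); the rest start empty
selection = [0, 2]
n_events = 3
drop_log = tuple(
    () if k in selection else ('IGNORED',)
    for k in range(max(n_events, max(selection) + 1)))
assert drop_log == ((), ('IGNORED',), ())
```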
@@ -559,6 +560,9 @@ def _check_consistency(self):
assert len(self.drop_log) >= len(self.events)
assert hasattr(self, '_times_readonly')
assert not self.times.flags['WRITEABLE']
assert isinstance(self.drop_log, tuple)
assert all(isinstance(log, tuple) for log in self.drop_log)
assert all(isinstance(s, str) for log in self.drop_log for s in log)

def load_data(self):
"""Load the data if not already preloaded.
@@ -760,13 +764,13 @@ def _reject_setup(self, reject, flat):
def _is_good_epoch(self, data, verbose=None):
"""Determine if epoch is good."""
if isinstance(data, str):
return False, [data]
return False, (data,)
if data is None:
return False, ['NO_DATA']
return False, ('NO_DATA',)
n_times = len(self.times)
if data.shape[1] < n_times:
# epoch is too short ie at the end of the data
return False, ['TOO_SHORT']
return False, ('TOO_SHORT',)
if self.reject is None and self.flat is None:
return True, None
else:
@@ -1353,6 +1357,7 @@ def _get_data(self, out=True, picks=None, item=None, verbose=None):
# e.g., when calling drop_bad w/new params
good_idx = []
n_out = 0
drop_log = list(self.drop_log)
assert n_events == len(self.selection)
for idx, sel in enumerate(self.selection):
if self.preload: # from memory
@@ -1368,9 +1373,11 @@ def _get_data(self, out=True, picks=None, item=None, verbose=None):
epoch = self._project_epoch(epoch_noproj)

epoch_out = epoch_noproj if self._do_delayed_proj else epoch
is_good, offending_reason = self._is_good_epoch(epoch)
is_good, bad_tuple = self._is_good_epoch(epoch)
if not is_good:
self.drop_log[sel] += offending_reason
assert isinstance(bad_tuple, tuple)
assert all(isinstance(x, str) for x in bad_tuple)
drop_log[sel] = drop_log[sel] + bad_tuple
continue
good_idx.append(idx)

@@ -1383,6 +1390,8 @@ def _get_data(self, out=True, picks=None, item=None, verbose=None):
dtype=epoch_out.dtype, order='C')
data[n_out] = epoch_out
n_out += 1
self.drop_log = tuple(drop_log)
del drop_log

self._bad_dropped = True
logger.info("%d bad epochs dropped" % (n_events - len(good_idx)))
@@ -1543,13 +1552,21 @@ def copy(self):
epochs : instance of Epochs
A copy of the object.
"""
raw = self._raw
del self._raw
new = deepcopy(self)
self._raw = raw
new._raw = raw
new._set_times(new.times) # sets RO
return new
return deepcopy(self)

def __deepcopy__(self, memodict):
"""Make a deepcopy."""
cls = self.__class__
result = cls.__new__(cls)
for k, v in self.__dict__.items():
# drop_log is immutable and _raw is private (and problematic to
# deepcopy)
if k in ('drop_log', '_raw', '_times_readonly'):
memodict[id(v)] = v
else:
v = deepcopy(v, memodict)
result.__dict__[k] = v
return result

@verbose
def save(self, fname, split_size='2GB', fmt='single', overwrite=False,
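The new ``__deepcopy__`` works by seeding ``memodict``: registering ``id(value) -> value`` before recursing makes ``copy.deepcopy`` hand back the shared object whenever it encounters it, even through nested references. A self-contained sketch of the trick (illustrative class, not MNE code):

```python
import copy


class Shares:
    """Deep-copy everything except attributes known to be safe to share."""

    _shared = ('drop_log',)  # immutable, so aliasing it is harmless

    def __init__(self, drop_log, data):
        self.drop_log = drop_log
        self.data = data

    def __deepcopy__(self, memodict):
        cls = self.__class__
        result = cls.__new__(cls)
        for key, value in self.__dict__.items():
            if key in self._shared:
                # deepcopy will now return `value` itself wherever it is seen
                memodict[id(value)] = value
            else:
                value = copy.deepcopy(value, memodict)
            result.__dict__[key] = value
        return result


a = Shares(drop_log=((), ('IGNORED',)), data=[1, 2, 3])
b = copy.deepcopy(a)
assert b.drop_log is a.drop_log  # shared, no copy made
assert b.data is not a.data      # genuinely deep-copied
```

Seeding ``memodict`` up front (rather than patching the copy afterwards) also means nested references to the shared objects are handled consistently, which is what lets ``_raw``, expensive and problematic to deep-copy, be skipped without the save/restore dance the old ``copy`` used.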
@@ -1902,8 +1919,10 @@ def _drop_log_stats(drop_log, ignore=('IGNORED',)):
perc : float
Total percentage of epochs dropped.
"""
if not isinstance(drop_log, list) or not isinstance(drop_log[0], list):
raise ValueError('drop_log must be a list of lists')
if not isinstance(drop_log, tuple) or \
not all(isinstance(d, tuple) for d in drop_log) or \
not all(isinstance(s, str) for d in drop_log for s in d):
raise TypeError('drop_log must be a tuple of tuple of str')
perc = 100 * np.mean([len(d) > 0 for d in drop_log
if not any(r in ignore for r in d)])
return perc
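Hypothetical usage of the updated validation, with the function body reproduced so the snippet runs standalone (values illustrative; note the error type changed from ``ValueError`` to ``TypeError``):

```python
import numpy as np


def _drop_log_stats(drop_log, ignore=('IGNORED',)):
    # Same logic as in the diff above, reproduced for a runnable example
    if not isinstance(drop_log, tuple) or \
            not all(isinstance(d, tuple) for d in drop_log) or \
            not all(isinstance(s, str) for d in drop_log for s in d):
        raise TypeError('drop_log must be a tuple of tuple of str')
    return 100 * np.mean([len(d) > 0 for d in drop_log
                          if not any(r in ignore for r in d)])


drop_log = ((), ('TOO_SHORT',), ('IGNORED',), ())
print(_drop_log_stats(drop_log))  # 33.3...: 1 of 3 counted epochs dropped

try:
    _drop_log_stats([[], ['TOO_SHORT']])  # lists are rejected now
except TypeError as err:
    print(err)
```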
@@ -2036,17 +2055,21 @@ class Epochs(BaseEpochs):
has been dropped, this attribute would be np.array([0, 2, 3]).
preload : bool
Indicates whether epochs are in memory.
drop_log : list of list
A list of the same length as the event array used to initialize the
drop_log : tuple of tuple
A tuple of the same length as the event array used to initialize the
Epochs object. If the i-th original event is still part of the
selection, drop_log[i] will be an empty list; otherwise it will be
a list of the reasons the event is not longer in the selection, e.g.:

'IGNORED' if it isn't part of the current subset defined by the user;
'NO_DATA' or 'TOO_SHORT' if epoch didn't contain enough data;
names of channels that exceeded the amplitude threshold;
'EQUALIZED_COUNTS' (see equalize_event_counts);
or 'USER' for user-defined reasons (see drop method).
selection, drop_log[i] will be an empty tuple; otherwise it will be
a tuple of the reasons the event is no longer in the selection, e.g.:

- 'IGNORED'
If it isn't part of the current subset defined by the user
- 'NO_DATA' or 'TOO_SHORT'
If the epoch didn't contain enough data
- Names of channels
If those channels exceeded the amplitude threshold
- 'EQUALIZED_COUNTS'
See :meth:`~mne.Epochs.equalize_event_counts`
- 'USER'
For user-defined reasons (see :meth:`~mne.Epochs.drop`).
filename : str
The filename of the object.
times : ndarray
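Concretely, a ``drop_log`` under the new scheme might look like this (an illustrative value only; ``'MEG 2443'`` is just an example channel name):

```python
# Epoch 0 kept; epoch 1 rejected because channel 'MEG 2443' exceeded the
# amplitude threshold; epoch 2 was never part of the user's subset
drop_log = ((), ('MEG 2443',), ('IGNORED',))
```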
@@ -2380,13 +2403,6 @@ def _get_drop_indices(event_times, method):
return indices


def _fix_fill(fill):
"""Fix bug on old scipy."""
if LooseVersion(scipy.__version__) < LooseVersion('0.12'):
fill = fill[:, np.newaxis]
return fill


def _minimize_time_diff(t_shorter, t_longer):
"""Find a boolean mask to minimize timing differences."""
from scipy.interpolate import interp1d
@@ -2413,7 +2429,7 @@ def _minimize_time_diff(t_shorter, t_longer):
x2 = np.arange(len(t_longer) - ii - 1)
t_keeps = np.array([t_longer[km] for km in keep_mask])
longer_interp = interp1d(x2, t_keeps, axis=1,
fill_value=_fix_fill(t_keeps[:, -1]),
fill_value=t_keeps[:, -1],
**kwargs)
d1 = longer_interp(x1) - t_shorter
d2 = shorter_interp(x2) - t_keeps
@@ -2430,7 +2446,7 @@ def _is_good(e, ch_names, channel_type_idx, reject, flat, full_report=False,
If full_report=True, it will give True/False as well as a list of all
offending channels.
"""
bad_list = list()
bad_tuple = tuple()
has_printed = False
checkable = np.ones(len(ch_names), dtype=bool)
checkable[np.array([c in ignore_chs
@@ -2448,23 +2464,23 @@ def _is_good(e, ch_names, channel_type_idx, reject, flat, full_report=False,
checkable_idx))[0]

if len(idx_deltas) > 0:
ch_name = [ch_names[idx[i]] for i in idx_deltas]
bad_names = [ch_names[idx[i]] for i in idx_deltas]
if (not has_printed):
logger.info(' Rejecting %s epoch based on %s : '
'%s' % (t, name, ch_name))
'%s' % (t, name, bad_names))
has_printed = True
if not full_report:
return False
else:
bad_list.extend(ch_name)
bad_tuple += tuple(bad_names)

if not full_report:
return True
else:
if bad_list == []:
if bad_tuple == ():
return True, None
else:
return False, bad_list
return False, bad_tuple


def _read_one_epoch_file(f, tree, preload):
@@ -2541,7 +2557,7 @@ def _read_one_epoch_file(f, tree, preload):
selection = np.array(tag.data)
elif kind == FIFF.FIFF_MNE_EPOCHS_DROP_LOG:
tag = read_tag(fid, pos)
drop_log = json.loads(tag.data)
drop_log = tuple(tuple(x) for x in json.loads(tag.data))
elif kind == FIFF.FIFF_MNE_EPOCHS_REJECT_FLAT:
tag = read_tag(fid, pos)
reject_params = json.loads(tag.data)
@@ -2604,7 +2620,7 @@ def _read_one_epoch_file(f, tree, preload):
if selection is None:
selection = np.arange(len(events))
if drop_log is None:
drop_log = [[] for _ in range(len(events))]
drop_log = ((),) * len(events)

return (info, data, data_tag, events, event_id, metadata, tmin, tmax,
baseline, selection, drop_log, epoch_shape, cals, reject_params,
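The ``tuple(tuple(x) for x in ...)`` conversion above is needed because JSON has no tuple type: the log is serialized as nested lists, so it must be re-frozen when read back. In isolation:

```python
import json

drop_log = ((), ('IGNORED',))
encoded = json.dumps(drop_log)  # '[[], ["IGNORED"]]' -- tuples become lists
decoded = json.loads(encoded)   # [[], ['IGNORED']]
restored = tuple(tuple(x) for x in decoded)
assert restored == drop_log
```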
@@ -2744,12 +2760,13 @@ def __init__(self, fname, proj=True, preload=True,
assert len(drop_log) % len(fnames) == 0
step = len(drop_log) // len(fnames)
offsets = np.arange(step, len(drop_log) + 1, step)
drop_log = list(drop_log)
for i1, i2 in zip(offsets[:-1], offsets[1:]):
other_log = drop_log[i1:i2]
for k, (a, b) in enumerate(zip(drop_log, other_log)):
if a == ['IGNORED'] and b != ['IGNORED']:
if a == ('IGNORED',) and b != ('IGNORED',):
drop_log[k] = b
drop_log = drop_log[:step]
drop_log = tuple(drop_log[:step])

# call BaseEpochs constructor
super(EpochsFIF, self).__init__(
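The merge above keeps, for each original event, whichever split file's entry is not ``'IGNORED'`` (each split ignores the epochs stored in the other files). A worked example with two splits (a standalone sketch of the loop above):

```python
import numpy as np

# Logs from two split files, concatenated; both cover the same two events
drop_log = [('IGNORED',), ('TOO_SHORT',), (), ('IGNORED',)]
step = 2  # entries per split

offsets = np.arange(step, len(drop_log) + 1, step)
for i1, i2 in zip(offsets[:-1], offsets[1:]):
    other_log = drop_log[i1:i2]
    for k, (a, b) in enumerate(zip(drop_log, other_log)):
        if a == ('IGNORED',) and b != ('IGNORED',):
            drop_log[k] = b  # take the split that actually saw this epoch
drop_log = tuple(drop_log[:step])

assert drop_log == ((), ('TOO_SHORT',))
```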
@@ -2949,7 +2966,7 @@ def _concatenate_epochs(epochs_list, with_data=True, add_offset=True):
baseline, tmin, tmax = out.baseline, out.tmin, out.tmax
info = deepcopy(out.info)
verbose = out.verbose
drop_log = deepcopy(out.drop_log)
drop_log = out.drop_log
event_id = deepcopy(out.event_id)
selection = out.selection
# offset is the last epoch + tmax + 10 second
@@ -2985,7 +3002,7 @@ def _concatenate_epochs(epochs_list, with_data=True, add_offset=True):
int((10 + tmax) * epochs.info['sfreq']))
events.append(evs)
selection = np.concatenate((selection, epochs.selection))
drop_log.extend(epochs.drop_log)
drop_log = drop_log + epochs.drop_log
event_id.update(epochs.event_id)
metadata.append(epochs.metadata)
events = np.concatenate(events, axis=0)