Skip to content

Commit

Permalink
Fix parsing of incomplete StreamInfo when connecting a Stream (#179)
Browse files Browse the repository at this point in the history
  • Loading branch information
mscheltienne committed Nov 20, 2023
1 parent dd4ead0 commit 3fecfbb
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 29 deletions.
56 changes: 28 additions & 28 deletions mne_lsl/lsl/stream_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,37 +346,37 @@ def get_channel_info(self) -> Info:
loc_array.append(value)
locs.append(loc_array)
ch = ch.next_sibling()
locs = np.array(locs)
locs = (
np.array([[np.nan] * 12] * self.n_channels)
if len(locs) == 0
else np.array(locs)
)

with info._unlock(update_redundant=True):
for k, (kind, coil_type, coord_frame, range_cal, loc) in enumerate(
zip(kinds, coil_types, coord_frames, range_cals, locs)
for var, name, fiff_named in (
(kinds, "kind", _ch_kind_named),
(coil_types, "coil_type", _ch_coil_type_named),
(coord_frames, "coord_frame", _coord_frame_named),
):
kind = _BaseStreamInfo._get_fiff_int_named(kind, "kind", _ch_kind_named)
if kind is not None:
info["chs"][k]["kind"] = kind

coil_type = _BaseStreamInfo._get_fiff_int_named(
coil_type, "coil_type", _ch_coil_type_named
)
if coil_type is not None:
info["chs"][k]["coil_type"] = coil_type

coord_frame = _BaseStreamInfo._get_fiff_int_named(
coord_frame, "coord_frame", _coord_frame_named
)
if coord_frame is not None:
info["chs"][k]["coord_frame"] = coord_frame

if range_cal is not None:
try:
info["chs"][k]["range"] = 1.0
info["chs"][k]["cal"] = float(range_cal)
except ValueError:
logger.warning(
"Could not cast 'range_cal' factor %s to float.", range_cal
)

if var is None:
continue
for k, value in enumerate(var):
value = _BaseStreamInfo._get_fiff_int_named(value, name, fiff_named)
if value is not None:
info["chs"][k][name] = value

if range_cals is not None:
for k, range_cal in enumerate(range_cals):
if range_cal is not None:
try:
info["chs"][k]["range"] = 1.0
info["chs"][k]["cal"] = float(range_cal)
except ValueError:
logger.warning(
"Could not cast 'range_cal' factor %s to float.",
range_cal,
)
for k, loc in enumerate(locs):
info["chs"][k]["loc"] = loc

# filters
Expand Down
3 changes: 3 additions & 0 deletions mne_lsl/stream/stream_lsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def connect(
if processing_flags is not None and (
processing_flags == "threadsafe" or "threadsafe" in processing_flags
):
self._reset_variables()
raise ValueError(
"The 'threadsafe' processing flag should not be provided for an "
"MNE-LSL Stream. If you require access to the underlying StreamInlet "
Expand All @@ -144,12 +145,14 @@ def connect(
# resolve and connect to available streams
sinfos = resolve_streams(timeout, self._name, self._stype, self._source_id)
if len(sinfos) != 1:
self._reset_variables()
raise RuntimeError(
"The provided arguments 'name', 'stype', and 'source_id' do not "
f"uniquely identify an LSL stream. {len(sinfos)} were found: "
f"{[(sinfo.name, sinfo.stype, sinfo.source_id) for sinfo in sinfos]}."
)
if sinfos[0].dtype == "string":
self._reset_variables()
raise RuntimeError(
"The Stream class is designed for numerical types. It does not support "
"string LSL streams. Please use a mne_lsl.lsl.StreamInlet directly to "
Expand Down
60 changes: 60 additions & 0 deletions mne_lsl/stream/tests/test_stream_lsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from mne.io.pick import _picks_to_idx

from mne_lsl import logger
from mne_lsl.lsl import StreamInfo, StreamOutlet
from mne_lsl.stream import StreamLSL as Stream
from mne_lsl.utils._tests import match_stream_and_raw_data
from mne_lsl.utils.logs import _use_log_level
Expand Down Expand Up @@ -106,6 +107,8 @@ def test_stream(mock_lsl_stream, acquisition_delay, raw):
)
# dtype
assert stream.dtype == stream.sinfo.dtype
# compensation grade
assert stream.compensation_grade is None
# disconnect
stream.disconnect()

Expand Down Expand Up @@ -133,6 +136,8 @@ def test_stream_invalid():
Stream(1, stype=101)
with pytest.raises(TypeError, match="must be an instance of str"):
Stream(1, source_id=101)
with pytest.raises(ValueError, match="must be a positive number"):
Stream(bufsize=2).connect(acquisition_delay=-1)


def test_stream_connection_no_args(mock_lsl_stream):
Expand Down Expand Up @@ -557,3 +562,58 @@ def _sleep_until_new_data(acq_delay, player):
1.1 * (player.chunk_size / player.info["sfreq"]),
)
)


def test_stream_str(close_io):
"""Test a stream on a string source."""
sinfo = StreamInfo("test_stream_str", "gaze", 1, 100, "string", "pytest")
outlet = StreamOutlet(sinfo)
assert outlet.dtype == "string"
with pytest.raises(
RuntimeError, match="Stream class is designed for numerical types"
):
Stream(bufsize=2, name="test_stream_str").connect()
close_io()


def test_stream_processing_flags(close_io):
"""Test a stream connection processing flags."""
sinfo = StreamInfo("test_stream_processing_flags", "gaze", 1, 100, "int8", "pytest")
outlet = StreamOutlet(sinfo)
assert outlet.dtype == np.int8
stream = Stream(bufsize=2, name="test_stream_processing_flags")
assert not stream.connected
with pytest.raises(
ValueError, match="'threadsafe' processing flag should not be provided"
):
stream.connect(processing_flags=("clocksync", "threadsafe"))
assert not stream.connected
stream.connect(processing_flags="all")
assert stream.connected
stream.disconnect()
assert not stream.connected
close_io()


def test_stream_irregularly_sampled(close_io):
"""Test a stream with an irregular sampling rate."""
sinfo = StreamInfo(
"test_stream_irregularly_sampled", "gaze", 1, 0, "int8", "pytest"
)
outlet = StreamOutlet(sinfo)
stream = Stream(bufsize=10, name="test_stream_irregularly_sampled")
stream.connect()
time.sleep(0.1) # give a bit of time to the stream to acquire the first chunks
assert stream.connected
data, _ = stream.get_data()
expected = np.zeros(stream.n_buffer, dtype=stream.dtype)
assert_allclose(data.squeeze(), expected)
outlet.push_sample(np.array([1]))
time.sleep(0.01)
data, _ = stream.get_data()
expected[-1] = 1
assert_allclose(data.squeeze(), expected)
with pytest.raises(RuntimeError, match="with an irregular sampling rate."):
stream._check_connected_and_regular_sampling("test")
stream.disconnect()
close_io()
2 changes: 1 addition & 1 deletion mne_lsl/utils/_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
Size of the buffer keeping track of the data received from the stream. If
the stream sampling rate ``sfreq`` is regular, ``bufsize`` is expressed in
seconds. The buffer will hold the last ``bufsize * sfreq`` samples (ceiled).
If the strean sampling sampling rate ``sfreq`` is irregular, ``bufsize`` is
If the stream sampling rate ``sfreq`` is irregular, ``bufsize`` is
expressed in samples. The buffer will hold the last ``bufsize`` samples."""

# -----------------------------------------------
Expand Down

0 comments on commit 3fecfbb

Please sign in to comment.