From 3fecfbbc2ff193396d7f62f3694571876b489652 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Mon, 20 Nov 2023 17:08:22 +0100 Subject: [PATCH] Fix parsing of incomplete StreamInfo when connecting a Stream (#179) --- mne_lsl/lsl/stream_info.py | 56 +++++++++++------------ mne_lsl/stream/stream_lsl.py | 3 ++ mne_lsl/stream/tests/test_stream_lsl.py | 60 +++++++++++++++++++++++++ mne_lsl/utils/_docs.py | 2 +- 4 files changed, 92 insertions(+), 29 deletions(-) diff --git a/mne_lsl/lsl/stream_info.py b/mne_lsl/lsl/stream_info.py index 2bb29e32c..20e41a2ed 100644 --- a/mne_lsl/lsl/stream_info.py +++ b/mne_lsl/lsl/stream_info.py @@ -346,37 +346,37 @@ def get_channel_info(self) -> Info: loc_array.append(value) locs.append(loc_array) ch = ch.next_sibling() - locs = np.array(locs) + locs = ( + np.array([[np.nan] * 12] * self.n_channels) + if len(locs) == 0 + else np.array(locs) + ) with info._unlock(update_redundant=True): - for k, (kind, coil_type, coord_frame, range_cal, loc) in enumerate( - zip(kinds, coil_types, coord_frames, range_cals, locs) + for var, name, fiff_named in ( + (kinds, "kind", _ch_kind_named), + (coil_types, "coil_type", _ch_coil_type_named), + (coord_frames, "coord_frame", _coord_frame_named), ): - kind = _BaseStreamInfo._get_fiff_int_named(kind, "kind", _ch_kind_named) - if kind is not None: - info["chs"][k]["kind"] = kind - - coil_type = _BaseStreamInfo._get_fiff_int_named( - coil_type, "coil_type", _ch_coil_type_named - ) - if coil_type is not None: - info["chs"][k]["coil_type"] = coil_type - - coord_frame = _BaseStreamInfo._get_fiff_int_named( - coord_frame, "coord_frame", _coord_frame_named - ) - if coord_frame is not None: - info["chs"][k]["coord_frame"] = coord_frame - - if range_cal is not None: - try: - info["chs"][k]["range"] = 1.0 - info["chs"][k]["cal"] = float(range_cal) - except ValueError: - logger.warning( - "Could not cast 'range_cal' factor %s to float.", range_cal - ) - + if var is None: + continue + for k, value in enumerate(var): + value = _BaseStreamInfo._get_fiff_int_named(value, name, fiff_named) + if value is not None: + info["chs"][k][name] = value + + if range_cals is not None: + for k, range_cal in enumerate(range_cals): + if range_cal is not None: + try: + info["chs"][k]["range"] = 1.0 + info["chs"][k]["cal"] = float(range_cal) + except ValueError: + logger.warning( + "Could not cast 'range_cal' factor %s to float.", + range_cal, + ) + for k, loc in enumerate(locs): info["chs"][k]["loc"] = loc # filters diff --git a/mne_lsl/stream/stream_lsl.py b/mne_lsl/stream/stream_lsl.py index fae47ecb3..633d781c0 100644 --- a/mne_lsl/stream/stream_lsl.py +++ b/mne_lsl/stream/stream_lsl.py @@ -133,6 +133,7 @@ def connect( if processing_flags is not None and ( processing_flags == "threadsafe" or "threadsafe" in processing_flags ): + self._reset_variables() raise ValueError( "The 'threadsafe' processing flag should not be provided for an " "MNE-LSL Stream. If you require access to the underlying StreamInlet " @@ -144,12 +145,14 @@ def connect( # resolve and connect to available streams sinfos = resolve_streams(timeout, self._name, self._stype, self._source_id) if len(sinfos) != 1: + self._reset_variables() raise RuntimeError( "The provided arguments 'name', 'stype', and 'source_id' do not " f"uniquely identify an LSL stream. {len(sinfos)} were found: " f"{[(sinfo.name, sinfo.stype, sinfo.source_id) for sinfo in sinfos]}." ) if sinfos[0].dtype == "string": + self._reset_variables() raise RuntimeError( "The Stream class is designed for numerical types. It does not support " "string LSL streams. Please use a mne_lsl.lsl.StreamInlet directly to " diff --git a/mne_lsl/stream/tests/test_stream_lsl.py b/mne_lsl/stream/tests/test_stream_lsl.py index 930fcd2d8..bdae3fdef 100644 --- a/mne_lsl/stream/tests/test_stream_lsl.py +++ b/mne_lsl/stream/tests/test_stream_lsl.py @@ -21,6 +21,7 @@ from mne.io.pick import _picks_to_idx from mne_lsl import logger +from mne_lsl.lsl import StreamInfo, StreamOutlet from mne_lsl.stream import StreamLSL as Stream from mne_lsl.utils._tests import match_stream_and_raw_data from mne_lsl.utils.logs import _use_log_level @@ -106,6 +107,8 @@ def test_stream(mock_lsl_stream, acquisition_delay, raw): ) # dtype assert stream.dtype == stream.sinfo.dtype + # compensation grade + assert stream.compensation_grade is None # disconnect stream.disconnect() @@ -133,6 +136,8 @@ def test_stream_invalid(): Stream(1, stype=101) with pytest.raises(TypeError, match="must be an instance of str"): Stream(1, source_id=101) + with pytest.raises(ValueError, match="must be a positive number"): + Stream(bufsize=2).connect(acquisition_delay=-1) def test_stream_connection_no_args(mock_lsl_stream): @@ -557,3 +562,58 @@ def _sleep_until_new_data(acq_delay, player): 1.1 * (player.chunk_size / player.info["sfreq"]), ) ) + + +def test_stream_str(close_io): + """Test a stream on a string source.""" + sinfo = StreamInfo("test_stream_str", "gaze", 1, 100, "string", "pytest") + outlet = StreamOutlet(sinfo) + assert outlet.dtype == "string" + with pytest.raises( + RuntimeError, match="Stream class is designed for numerical types" + ): + Stream(bufsize=2, name="test_stream_str").connect() + close_io() + + +def test_stream_processing_flags(close_io): + """Test a stream connection processing flags.""" + sinfo = StreamInfo("test_stream_processing_flags", "gaze", 1, 100, "int8", "pytest") + outlet = StreamOutlet(sinfo) + assert outlet.dtype == np.int8 + stream = Stream(bufsize=2, name="test_stream_processing_flags") + assert not stream.connected + with pytest.raises( + ValueError, match="'threadsafe' processing flag should not be provided" + ): + stream.connect(processing_flags=("clocksync", "threadsafe")) + assert not stream.connected + stream.connect(processing_flags="all") + assert stream.connected + stream.disconnect() + assert not stream.connected + close_io() + + +def test_stream_irregularly_sampled(close_io): + """Test a stream with an irregular sampling rate.""" + sinfo = StreamInfo( + "test_stream_irregularly_sampled", "gaze", 1, 0, "int8", "pytest" + ) + outlet = StreamOutlet(sinfo) + stream = Stream(bufsize=10, name="test_stream_irregularly_sampled") + stream.connect() + time.sleep(0.1) # give a bit of time to the stream to acquire the first chunks + assert stream.connected + data, _ = stream.get_data() + expected = np.zeros(stream.n_buffer, dtype=stream.dtype) + assert_allclose(data.squeeze(), expected) + outlet.push_sample(np.array([1])) + time.sleep(0.01) + data, _ = stream.get_data() + expected[-1] = 1 + assert_allclose(data.squeeze(), expected) + with pytest.raises(RuntimeError, match="with an irregular sampling rate."): + stream._check_connected_and_regular_sampling("test") + stream.disconnect() + close_io() diff --git a/mne_lsl/utils/_docs.py b/mne_lsl/utils/_docs.py index a499e8e0d..16ee568cd 100644 --- a/mne_lsl/utils/_docs.py +++ b/mne_lsl/utils/_docs.py @@ -42,7 +42,7 @@ Size of the buffer keeping track of the data received from the stream. If the stream sampling rate ``sfreq`` is regular, ``bufsize`` is expressed in seconds. The buffer will hold the last ``bufsize * sfreq`` samples (ceiled). - If the strean sampling sampling rate ``sfreq`` is irregular, ``bufsize`` is + If the stream sampling rate ``sfreq`` is irregular, ``bufsize`` is expressed in samples. The buffer will hold the last ``bufsize`` samples.""" # -----------------------------------------------