diff --git a/spinn_machine/spinn_machine.cfg b/spinn_machine/spinn_machine.cfg index 5cc2cee1..7c9822f1 100644 --- a/spinn_machine/spinn_machine.cfg +++ b/spinn_machine/spinn_machine.cfg @@ -1,4 +1,9 @@ -# Adds or overwrites values in SpiNNUtils/spinn_utilities/spinn_utilities.cfg +# DO NOT EDIT! +# The are the default values +# Edit the cfg in your home directory to change your preferences +# Add / Edit a cfg in the run directory for script specific changes + +# Adds to values in SpiNNUtils/spinn_utilities/spinn_utilities.cfg [Machine] version = None @@ -6,8 +11,6 @@ version = None versions = None width = None height = None -# Note: if json_path is set all other configs for virtual boards are ignored -json_path = None # Can decrease actual but never increase. Used for testing max_sdram_allowed_per_chip = None @@ -42,3 +45,8 @@ max_machine_core = None # Urls for servers remote_spinnaker_url = None spalloc_server = None +# machine name is typically a URL and then version is required +machine_name = None + +# If using virtual_board both width and height must be set +virtual_board = False diff --git a/spinn_machine/version/abstract_version.py b/spinn_machine/version/abstract_version.py index 6ff2cd79..73f9c763 100644 --- a/spinn_machine/version/abstract_version.py +++ b/spinn_machine/version/abstract_version.py @@ -527,3 +527,6 @@ def version_parse_cores_string(self, core_string: str) -> Iterable[int]: :rtype: list(int) """ raise NotImplementedError + + def __hash__(self): + return self.number diff --git a/spinn_machine/version/version_201.py b/spinn_machine/version/version_201.py index 8867b0dc..ea84696d 100644 --- a/spinn_machine/version/version_201.py +++ b/spinn_machine/version/version_201.py @@ -89,3 +89,6 @@ def spinnaker_links(self) -> List[Tuple[int, int, int]]: @overrides(VersionSpin2.fpga_links) def fpga_links(self) -> List[Tuple[int, int, int, int, int]]: return [] + + def __eq__(self, other): + return isinstance(other, Version201) diff --git a/spinn_machine/version/version_248.py b/spinn_machine/version/version_248.py index c6ae7fc6..50bc4ad7 100644 --- a/spinn_machine/version/version_248.py +++ b/spinn_machine/version/version_248.py @@ -68,3 +68,6 @@ def spinnaker_links(self) -> List[Tuple[int, int, int]]: def fpga_links(self) -> List[Tuple[int, int, int, int, int]]: # TODO return [] + + def __eq__(self, other): + return isinstance(other, Version248) diff --git a/spinn_machine/version/version_3.py b/spinn_machine/version/version_3.py index 9d82f2a0..4cdfe41e 100644 --- a/spinn_machine/version/version_3.py +++ b/spinn_machine/version/version_3.py @@ -85,3 +85,6 @@ def spinnaker_links(self) -> List[Tuple[int, int, int]]: @overrides(VersionSpin1.fpga_links) def fpga_links(self) -> List[Tuple[int, int, int, int, int]]: return [] + + def __eq__(self, other): + return isinstance(other, Version3) diff --git a/spinn_machine/version/version_5.py b/spinn_machine/version/version_5.py index 519cbefd..70711600 100644 --- a/spinn_machine/version/version_5.py +++ b/spinn_machine/version/version_5.py @@ -82,3 +82,6 @@ def fpga_links(self) -> List[Tuple[int, int, int, int, int]]: (7, 5, 0, 2, 12), (7, 5, 1, 2, 11), (7, 6, 0, 2, 10), (7, 6, 1, 2, 9), (7, 7, 0, 2, 8), (7, 7, 1, 2, 7), (7, 7, 2, 2, 6)] + + def __eq__(self, other): + return isinstance(other, Version5) diff --git a/spinn_machine/version/version_factory.py b/spinn_machine/version/version_factory.py index e655ff10..fc19f07f 100644 --- a/spinn_machine/version/version_factory.py +++ b/spinn_machine/version/version_factory.py @@ -15,9 +15,9 @@ from __future__ import annotations import logging import sys -from typing import TYPE_CHECKING +from typing import Optional, TYPE_CHECKING from spinn_utilities.config_holder import ( - get_config_int_or_none, get_config_str_or_none) + get_config_bool, get_config_int_or_none, get_config_str_or_none) from spinn_utilities.log import FormatAdapter from spinn_machine.exceptions import SpinnMachineException from .version_strings import VersionStrings @@ -41,13 +41,47 @@ def version_factory() -> AbstractVersion: :return: A subclass of AbstractVersion :raises SpinnMachineException: If the cfg version is not set correctly """ - # Delayed import to avoid circular imports - # pylint: disable=import-outside-toplevel - from .version_3 import Version3 - from .version_5 import Version5 - from .version_201 import Version201 - from .version_248 import Version248 + cfg_version = _get_cfg_version() + url_version = _get_url_version() + size_version = _get_size_version() + + version: Optional[AbstractVersion] = None + if cfg_version is None: + if url_version is None: + version = None + else: + version = _number_to_version(url_version) + else: + if url_version is None: + version = _number_to_version(cfg_version) + else: + version_cfg = _number_to_version(cfg_version) + version_url = _number_to_version(url_version) + if version_cfg == version_url: + version = version_cfg + else: + raise_version_error("Incorrect version", cfg_version) + if size_version is None: + if version is None: + raise_version_error("No version", None) + else: + return version + else: + if version is None: + logger.warning("Please add a version to your cfg file.") + return _number_to_version(size_version) + else: + version_sized = _number_to_version(size_version) + if version == version_sized: + return version + else: + raise SpinnMachineException( + "cfg width and height do not match other cfg setting.") + raise SpinnMachineException("Should not get here") + + +def _get_cfg_version() -> Optional[int]: version = get_config_int_or_none("Machine", "version") versions = get_config_str_or_none("Machine", "versions") if versions is not None: @@ -59,6 +93,76 @@ def version_factory() -> AbstractVersion: # Use the fact that we run actions against different python versions minor = sys.version_info.minor version = options[minor % len(options)] + if version is None: + logger.warning( + "The cfg has no version. This is deprecated! Please add a version") + return version + + +def _get_url_version() -> Optional[int]: + spalloc_server = get_config_str_or_none("Machine", "spalloc_server") + remote_spinnaker_url = get_config_str_or_none( + "Machine", "remote_spinnaker_url") + machine_name = get_config_str_or_none("Machine", "machineName") + virtual_board = get_config_bool("Machine", "virtual_board") + + if spalloc_server is not None: + if remote_spinnaker_url is not None: + raise SpinnMachineException( + "Both spalloc_server and remote_spinnaker_url " + "specified in cfg") + if machine_name is not None: + raise SpinnMachineException( + "Both spalloc_server and machine_name specified in cfg") + if virtual_board: + raise SpinnMachineException( + "Both spalloc_server and virtual_board specified in cfg") + return 5 + + if remote_spinnaker_url is not None: + if machine_name is not None: + raise SpinnMachineException( + "Both remote_spinnaker_url and machine_name specified in cfg") + if virtual_board: + raise SpinnMachineException( + "Both remote_spinnaker_url and virtual_board specified in cfg") + return 5 + + if machine_name is not None: + if virtual_board: + raise SpinnMachineException( + "Both machine_name and virtual_board specified in cfg") + + return None + + +def _get_size_version() -> Optional[int]: + height = get_config_int_or_none("Machine", "height") + width = get_config_int_or_none("Machine", "width") + if height is None: + if width is None: + return None + else: + raise SpinnMachineException("cfg has width but not height") + else: + if width is None: + raise SpinnMachineException("cfg has height but not width") + else: + if height == width == 2: + return 3 + elif height == width == 1: + return 201 + # if width and height are valid checked later + return 5 + + +def _number_to_version(version: int): + # Delayed import to avoid circular imports + # pylint: disable=import-outside-toplevel + from .version_3 import Version3 + from .version_5 import Version5 + from .version_201 import Version201 + from .version_248 import Version248 if version in [2, 3]: return Version3() @@ -72,29 +176,28 @@ def version_factory() -> AbstractVersion: if version == SPIN2_48CHIP: return Version248() - spalloc_server = get_config_str_or_none("Machine", "spalloc_server") - if spalloc_server is not None: - return Version5() + raise SpinnMachineException(f"Unexpected cfg [Machine]version {version}") - remote_spinnaker_url = get_config_str_or_none( - "Machine", "remote_spinnaker_url") - if remote_spinnaker_url is not None: - return Version5() +def raise_version_error(error: str, version: Optional[int]): + """ + Collects main cfg values and raises an exception + + :param str error: message for the exception + :param version: version claimed + :type version: int or None + :raises SpinnMachineException: Always! + """ height = get_config_int_or_none("Machine", "height") width = get_config_int_or_none("Machine", "width") - if height is not None and width is not None: - logger.info("Your cfg file does not have a version") - if height == width == 2: - return Version3() - elif height == width == 1: - return Version201() - return Version5() - if version is None: - machine_name = get_config_str_or_none("Machine", "machineName") - raise SpinnMachineException( - "cfg [Machine]version {version} is None. " - f"Other cfg settings are {machine_name=} {spalloc_server=}, " - f"{remote_spinnaker_url=} {width=} {height=}") - raise SpinnMachineException(f"Unexpected cfg [Machine]version {version}") + spalloc_server = get_config_str_or_none("Machine", + "spalloc_server") + remote_spinnaker_url = get_config_str_or_none( + "Machine", "remote_spinnaker_url") + machine_name = get_config_str_or_none("Machine", "machineName") + virtual_board = get_config_bool("Machine", "virtual_board") + raise SpinnMachineException( + f"{error} with cfg [Machine] values {version=}, " + f"{machine_name=}, {spalloc_server=}, {remote_spinnaker_url=}, " + f"{virtual_board=}, {width=}, and {height=}") diff --git a/unittests/test_version_factory.py b/unittests/test_version_factory.py new file mode 100644 index 00000000..0f74f2a2 --- /dev/null +++ b/unittests/test_version_factory.py @@ -0,0 +1,95 @@ +# Copyright (c) 2014 The University of Manchester +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from spinn_utilities.config_holder import set_config +from spinn_machine.config_setup import unittest_setup +from spinn_machine.exceptions import ( + SpinnMachineException) +from spinn_machine.version.version_5 import Version5 +from spinn_machine.version.version_factory import version_factory + + +class TestVersionFactory(unittest.TestCase): + + def setUp(self): + unittest_setup() + + def test_no_info(self): + with self.assertRaises(SpinnMachineException): + version_factory() + + def test_bad_spalloc_3(self): + set_config("Machine", "version", 3) + set_config("Machine", "spalloc_server", "Somewhere") + with self.assertRaises(SpinnMachineException): + version_factory() + + def test_ok_spalloc_4(self): + set_config("Machine", "version", 4) + set_config("Machine", "spalloc_server", "Somewhere") + version = version_factory() + self.assertEqual(Version5(), version) + + def test_ok_spalloc(self): + # warning this behaviour may break if spalloc ever support spin2 + set_config("Machine", "spalloc_server", "Somewhere") + version = version_factory() + self.assertEqual(Version5(), version) + + def test_ok_remote_5(self): + set_config("Machine", "version", 4) + set_config("Machine", "remote_spinnaker_url", "Somewhere") + version = version_factory() + self.assertEqual(Version5(), version) + + def test_bad_spalloc_and_remote(self): + set_config("Machine", "spalloc_server", "Somewhere") + set_config("Machine", "remote_spinnaker_url", "Somewhere") + with self.assertRaises(SpinnMachineException): + version_factory() + + def test_bad_spalloc_and_name(self): + set_config("Machine", "spalloc_server", "Somewhere") + set_config("Machine", "machine_name", "Somewhere") + with self.assertRaises(SpinnMachineException): + version_factory() + + def test_bad_spalloc_and_virtual(self): + set_config("Machine", "spalloc_server", "Somewhere") + set_config("Machine", "virtual_board", "True") + with self.assertRaises(SpinnMachineException): + version_factory() + + def test_bad_remote_and_name(self): + set_config("Machine", "remote_spinnaker_url", "Somewhere") + set_config("Machine", "machine_name", "Somewhere") + with self.assertRaises(SpinnMachineException): + version_factory() + + def test_bad_remote_and_virtual(self): + set_config("Machine", "remote_spinnaker_url", "Somewhere") + set_config("Machine", "virtual_board", "True") + with self.assertRaises(SpinnMachineException): + version_factory() + + def test_bad_name_and_virtual(self): + set_config("Machine", "machine_name", "Somewhere") + set_config("Machine", "virtual_board", "True") + with self.assertRaises(SpinnMachineException): + version_factory() + + +if __name__ == '__main__': + unittest.main()