Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix parallel optimisation and increase coverage #299

Merged
merged 8 commits into from
Apr 22, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ codesigned binaries and source distributions via `sigstore-python`.

## Bug Fixes

- [#299](https://github.com/pybop-team/PyBOP/pull/299) - Bugfix multiprocessing support for Linux, MacOS, Windows (WSL) and improves coverage.
- [#270](https://github.com/pybop-team/PyBOP/pull/270) - Updates PR template.
- [#91](https://github.com/pybop-team/PyBOP/issues/91) - Adds a check on the number of parameters for CMAES and makes XNES the default optimiser.

Expand Down
18 changes: 18 additions & 0 deletions pybop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,24 @@
import sys
from os import path

#
# Multiprocessing
#
try:
import multiprocessing as mp
if sys.platform == "win32":
mp.set_start_method("spawn")
else:
mp.set_start_method("fork")
except Exception as e: # pragma: no cover
error_message = (
"Multiprocessing context could not be set. "
"Continuing import without setting context.\n"
f"Error: {e}"
) # pragma: no cover
print(error_message) # pragma: no cover
pass # pragma: no cover

#
# Version info
#
Expand Down
4 changes: 2 additions & 2 deletions pybop/_optimisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,8 @@ def _run_pints(self):

# For population based optimisers, don't use more workers than
# particles!
if isinstance(self._optimiser, pints.PopulationBasedOptimiser):
n_workers = min(n_workers, self._optimiser.population_size())
if isinstance(self.optimiser, pints.PopulationBasedOptimiser):
n_workers = min(n_workers, self.optimiser.population_size())
evaluator = pints.ParallelEvaluator(f, n_workers=n_workers)
else:
evaluator = pints.SequentialEvaluator(f)
Expand Down
6 changes: 6 additions & 0 deletions tests/integration/test_optimisation_options.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import sys

import numpy as np
import pytest

Expand Down Expand Up @@ -86,6 +88,10 @@ def test_optimisation_f_guessed(self, f_guessed, spm_costs):
parameterisation.set_max_iterations(125)
parameterisation.set_f_guessed_tracking(f_guessed)

# Set parallelisation if not on Windows
if sys.platform != "win32":
parameterisation.set_parallel(True)

initial_cost = parameterisation.cost(spm_costs.x0)
x, final_cost = parameterisation.run()

Expand Down
37 changes: 37 additions & 0 deletions tests/unit/test_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import importlib
import sys
from unittest.mock import patch

import pytest


class TestImport:
@pytest.mark.unit
def test_multiprocessing_init_non_win32(self, monkeypatch):
"""Test multiprocessing init on non-Windows platforms"""
monkeypatch.setattr(sys, "platform", "linux")
# Unload pybop and its sub-modules
self.unload_pybop()
with patch("multiprocessing.set_start_method") as mock_set_start_method:
importlib.import_module("pybop")
mock_set_start_method.assert_called_once_with("fork")

@pytest.mark.unit
def test_multiprocessing_init_win32(self, monkeypatch):
"""Test multiprocessing init on Windows"""
monkeypatch.setattr(sys, "platform", "win32")
self.unload_pybop()
with patch("multiprocessing.set_start_method") as mock_set_start_method:
importlib.import_module("pybop")
mock_set_start_method.assert_called_once_with("spawn")

def unload_pybop(self):
"""
Unload pybop and its sub-modules. Credit PyBaMM team:
https://github.com/pybamm-team/PyBaMM/blob/develop/tests/unit/test_util.py
BradyPlanden marked this conversation as resolved.
Show resolved Hide resolved
"""
# Unload pybop and its sub-modules
for module_name in list(sys.modules.keys()):
base_module_name = module_name.split(".")[0]
if base_module_name == "pybop":
sys.modules.pop(module_name)
Loading