Skip to content

Commit

Permalink
SlurmScheduler: fix bug in validation of job resources (#4555)
Browse files Browse the repository at this point in the history
The `SlurmJobResource` resource class used by the `SlurmScheduler`
plugin contained a bug in the `validate_resources` methods that would
cause a float value to be set for the `num_cores_per_mpiproc` field in
certain cases. This would cause the submit script to fail because SLURM
only accepts integers for the corresponding `--ncpus-per-task` flag.

The reason is that the code was incorrectly using `isinstance(_, int)`
to check that the divison of `num_cores_per_machine` over
`num_mpiprocs_per_machine` is an integer. In addition to the negation
missing in the conditional, this is not the correct way of checking
whether a division is an integer. Instead it should check that the value
is identical after it is cast to `int`.
  • Loading branch information
chrisjsewell committed Nov 11, 2020
1 parent ac4c881 commit c42a86b
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 97 deletions.
5 changes: 3 additions & 2 deletions aiida/schedulers/plugins/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def validate_resources(cls, **kwargs):
"""
resources = super().validate_resources(**kwargs)

# In this plugin we never used num_cores_per_machine so if it is not defined it is OK.
if resources.num_cores_per_machine is not None and resources.num_cores_per_mpiproc is not None:
if resources.num_cores_per_machine != resources.num_cores_per_mpiproc * resources.num_mpiprocs_per_machine:
raise ValueError(
Expand All @@ -130,13 +131,13 @@ def validate_resources(cls, **kwargs):
if resources.num_cores_per_machine < 1:
raise ValueError('num_cores_per_machine must be greater than or equal to one.')

# In this plugin we never used num_cores_per_machine so if it is not defined it is OK.
resources.num_cores_per_mpiproc = (resources.num_cores_per_machine / resources.num_mpiprocs_per_machine)
if isinstance(resources.num_cores_per_mpiproc, int):
if int(resources.num_cores_per_mpiproc) != resources.num_cores_per_mpiproc:
raise ValueError(
'`num_cores_per_machine` must be equal to `num_cores_per_mpiproc * num_mpiprocs_per_machine` and in'
' particular it should be a multiple of `num_cores_per_mpiproc` and/or `num_mpiprocs_per_machine`'
)
resources.num_cores_per_mpiproc = int(resources.num_cores_per_mpiproc)

return resources

Expand Down
209 changes: 114 additions & 95 deletions tests/schedulers/test_slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@
# For further information please visit http://www.aiida.net #
###########################################################################
"""Tests for the SLURM scheduler plugin."""
#pylint: disable=no-self-use,line-too-long
import unittest
import logging
import uuid
import datetime

import pytest

from aiida.schedulers.plugins.slurm import SlurmScheduler, JobState
from aiida.schedulers.plugins.slurm import SlurmJobResource, SlurmScheduler, JobState
from aiida.schedulers import SchedulerError

# pylint: disable=line-too-long
# job_id, state_raw, annotation, executing_host, username, number_nodes, number_cpus, allocated_machines, partition, time_limit, time_used, dispatch_time, job_name, submission_time
# See SlurmScheduler.fields
TEXT_SQUEUE_TO_TEST = """862540^^^PD^^^Dependency^^^n/a^^^user1^^^20^^^640^^^(Dependency)^^^normal^^^1-00:00:00^^^0:00^^^N/A^^^longsqw_L24_q_10_0^^^2013-05-22T01:41:11
Expand All @@ -37,9 +37,41 @@
JOBS_RUNNING = ['862538', '861352', '863553', '863554']


def test_resource_validation():
"""Tests to verify that resources are correctly validated."""
with pytest.raises(
ValueError,
match='At least two among `num_machines`, `num_mpiprocs_per_machine` or `tot_num_mpiprocs` must be specified.'
):
SlurmJobResource()

output = SlurmJobResource(num_machines=1, num_mpiprocs_per_machine=2)
assert output == {
'num_machines': 1,
'num_mpiprocs_per_machine': 2,
'num_cores_per_machine': None,
'num_cores_per_mpiproc': None,
'tot_num_mpiprocs': 2
}

with pytest.raises(
ValueError, match='`tot_num_mpiprocs` is not equal to `num_mpiprocs_per_machine \\* num_machines`.'
):
SlurmJobResource(num_cores_per_machine=1, tot_num_mpiprocs=1, num_mpiprocs_per_machine=2)

with pytest.raises(ValueError, match='num_cores_per_machine must be greater than or equal to one.'):
SlurmJobResource(num_machines=1, tot_num_mpiprocs=1, num_cores_per_machine=0)

with pytest.raises(
ValueError,
match='num_cores_per_machine` must be equal to `num_cores_per_mpiproc \\* num_mpiprocs_per_machine`'
):
SlurmJobResource(num_machines=1, tot_num_mpiprocs=4, num_cores_per_machine=3)


class TestParserSqueue(unittest.TestCase):
"""
Tests to verify if teh function _parse_joblist_output behave correctly
Tests to verify if the function _parse_joblist_output behave correctly
The tests is done parsing a string defined above, to be used offline
"""

Expand All @@ -58,37 +90,37 @@ def test_parse_common_joblist_output(self):

# The parameters are hard coded in the text to parse
job_parsed = len(job_list)
self.assertEqual(job_parsed, JOBS_ON_CLUSTER)
assert job_parsed == JOBS_ON_CLUSTER

job_running_parsed = len([j for j in job_list if j.job_state \
and j.job_state == JobState.RUNNING])
self.assertEqual(len(JOBS_RUNNING), job_running_parsed)
assert len(JOBS_RUNNING) == job_running_parsed

job_held_parsed = len([j for j in job_list if j.job_state and j.job_state == JobState.QUEUED_HELD])
self.assertEqual(JOBS_HELD, job_held_parsed)
assert JOBS_HELD == job_held_parsed

job_queued_parsed = len([j for j in job_list if j.job_state and j.job_state == JobState.QUEUED])
self.assertEqual(JOBS_QUEUED, job_queued_parsed)
assert JOBS_QUEUED == job_queued_parsed

parsed_running_users = [j.job_owner for j in job_list if j.job_state and j.job_state == JobState.RUNNING]
self.assertEqual(set(USERS_RUNNING), set(parsed_running_users))
assert set(USERS_RUNNING) == set(parsed_running_users)

parsed_running_jobs = [j.job_id for j in job_list if j.job_state and j.job_state == JobState.RUNNING]
self.assertEqual(set(JOBS_RUNNING), set(parsed_running_jobs))
assert set(JOBS_RUNNING) == set(parsed_running_jobs)

self.assertEqual(job_dict['863553'].requested_wallclock_time_seconds, 30 * 60) # pylint: disable=invalid-name
self.assertEqual(job_dict['863553'].wallclock_time_seconds, 29 * 60 + 29)
self.assertEqual(job_dict['863553'].dispatch_time, datetime.datetime(2013, 5, 23, 11, 44, 11))
self.assertEqual(job_dict['863553'].submission_time, datetime.datetime(2013, 5, 23, 10, 42, 11))
assert job_dict['863553'].requested_wallclock_time_seconds, 30 * 60 # pylint: disable=invalid-name
assert job_dict['863553'].wallclock_time_seconds, 29 * 60 + 29
assert job_dict['863553'].dispatch_time, datetime.datetime(2013, 5, 23, 11, 44, 11)
assert job_dict['863553'].submission_time, datetime.datetime(2013, 5, 23, 10, 42, 11)

self.assertEqual(job_dict['863100'].annotation, 'Resources')
self.assertEqual(job_dict['863100'].num_machines, 32)
self.assertEqual(job_dict['863100'].num_mpiprocs, 1024)
self.assertEqual(job_dict['863100'].queue_name, 'normal')
assert job_dict['863100'].annotation == 'Resources'
assert job_dict['863100'].num_machines == 32
assert job_dict['863100'].num_mpiprocs == 1024
assert job_dict['863100'].queue_name == 'normal'

self.assertEqual(job_dict['861352'].title, 'Pressure_PBEsol_0')
assert job_dict['861352'].title == 'Pressure_PBEsol_0'

self.assertEqual(job_dict['863554'].requested_wallclock_time_seconds, None) # pylint: disable=invalid-name
assert job_dict['863554'].requested_wallclock_time_seconds is None # pylint: disable=invalid-name

# allocated_machines is not implemented in this version of the plugin
# for j in job_list:
Expand All @@ -108,68 +140,55 @@ def test_parse_failed_squeue_output(self):
scheduler = SlurmScheduler()

# non-zero return value should raise
with self.assertRaises(SchedulerError):
_ = scheduler._parse_joblist_output(1, TEXT_SQUEUE_TO_TEST, '') # pylint: disable=protected-access
with pytest.raises(SchedulerError, match='squeue returned exit code 1'):
scheduler._parse_joblist_output(1, TEXT_SQUEUE_TO_TEST, '') # pylint: disable=protected-access

# non-empty stderr should be logged
with self.assertLogs(scheduler.logger, 'WARNING'):
_ = scheduler._parse_joblist_output(0, TEXT_SQUEUE_TO_TEST, 'error message') # pylint: disable=protected-access


class TestTimes(unittest.TestCase):
"""Test time parsing of SLURM scheduler plugin."""

def test_time_conversion(self):
"""
Test conversion of (relative) times.
From docs, acceptable time formats include
"minutes", "minutes:seconds", "hours:minutes:seconds",
"days-hours", "days-hours:minutes" and "days-hours:minutes:seconds".
"""
# pylint: disable=protected-access
scheduler = SlurmScheduler()
self.assertEqual(scheduler._convert_time('2'), 2 * 60)
self.assertEqual(scheduler._convert_time('02'), 2 * 60)

self.assertEqual(scheduler._convert_time('02:3'), 2 * 60 + 3)
self.assertEqual(scheduler._convert_time('02:03'), 2 * 60 + 3)

self.assertEqual(scheduler._convert_time('1:02:03'), 3600 + 2 * 60 + 3)
self.assertEqual(scheduler._convert_time('01:02:03'), 3600 + 2 * 60 + 3)

self.assertEqual(scheduler._convert_time('1-3'), 86400 + 3 * 3600)
self.assertEqual(scheduler._convert_time('01-3'), 86400 + 3 * 3600)
self.assertEqual(scheduler._convert_time('01-03'), 86400 + 3 * 3600)
with self.assertLogs(scheduler.logger, logging.WARNING):
scheduler._parse_joblist_output(0, TEXT_SQUEUE_TO_TEST, 'error message') # pylint: disable=protected-access


@pytest.mark.parametrize(
'value,expected', [('2', 2 * 60), ('02', 2 * 60), ('02:3', 2 * 60 + 3), ('02:03', 2 * 60 + 3),
('1:02:03', 3600 + 2 * 60 + 3), ('01:02:03', 3600 + 2 * 60 + 3), ('1-3', 86400 + 3 * 3600),
('01-3', 86400 + 3 * 3600), ('01-03', 86400 + 3 * 3600), ('1-3:5', 86400 + 3 * 3600 + 5 * 60),
('01-3:05', 86400 + 3 * 3600 + 5 * 60), ('01-03:05', 86400 + 3 * 3600 + 5 * 60),
('1-3:5:7', 86400 + 3 * 3600 + 5 * 60 + 7), ('01-3:05:7', 86400 + 3 * 3600 + 5 * 60 + 7),
('01-03:05:07', 86400 + 3 * 3600 + 5 * 60 + 7), ('UNLIMITED', 2**31 - 1), ('NOT_SET', None)]
)
def test_time_conversion(value, expected):
"""
Test conversion of (relative) times.
self.assertEqual(scheduler._convert_time('1-3:5'), 86400 + 3 * 3600 + 5 * 60)
self.assertEqual(scheduler._convert_time('01-3:05'), 86400 + 3 * 3600 + 5 * 60)
self.assertEqual(scheduler._convert_time('01-03:05'), 86400 + 3 * 3600 + 5 * 60)
From docs, acceptable time formats include
"minutes", "minutes:seconds", "hours:minutes:seconds",
"days-hours", "days-hours:minutes" and "days-hours:minutes:seconds".
"""
# pylint: disable=protected-access
scheduler = SlurmScheduler()
assert scheduler._convert_time(value) == expected

self.assertEqual(scheduler._convert_time('1-3:5:7'), 86400 + 3 * 3600 + 5 * 60 + 7)
self.assertEqual(scheduler._convert_time('01-3:05:7'), 86400 + 3 * 3600 + 5 * 60 + 7)
self.assertEqual(scheduler._convert_time('01-03:05:07'), 86400 + 3 * 3600 + 5 * 60 + 7)

self.assertEqual(scheduler._convert_time('UNLIMITED'), 2**31 - 1)
self.assertEqual(scheduler._convert_time('NOT_SET'), None)
def test_time_conversion_errors(caplog):
"""Test conversion of (relative) times for bad inputs."""
# pylint: disable=protected-access
scheduler = SlurmScheduler()

# Disable logging to avoid excessive output during test
logging.disable(logging.ERROR)
with self.assertRaises(ValueError):
# Disable logging to avoid excessive output during test
with caplog.at_level(logging.CRITICAL):
with pytest.raises(ValueError, match='Unrecognized format for time string.'):
# Empty string not valid
scheduler._convert_time('')
with self.assertRaises(ValueError):
with pytest.raises(ValueError, match='Unrecognized format for time string.'):
# there should be something after the dash
scheduler._convert_time('1-')
with self.assertRaises(ValueError):
with pytest.raises(ValueError, match='Unrecognized format for time string.'):
# there should be something after the dash
# there cannot be a dash after the colons
scheduler._convert_time('1:2-3')
# Reset logging level
logging.disable(logging.NOTSET)


class TestSubmitScript(unittest.TestCase):
class TestSubmitScript:
"""Test submit script generation by SLURM scheduler plugin."""

def test_submit_script(self):
Expand All @@ -194,13 +213,13 @@ def test_submit_script(self):

submit_script_text = scheduler.get_submit_script(job_tmpl)

self.assertTrue(submit_script_text.startswith('#!/bin/bash'))
assert submit_script_text.startswith('#!/bin/bash')

self.assertTrue('#SBATCH --no-requeue' in submit_script_text)
self.assertTrue('#SBATCH --time=1-00:00:00' in submit_script_text)
self.assertTrue('#SBATCH --nodes=1' in submit_script_text)
assert '#SBATCH --no-requeue' in submit_script_text
assert '#SBATCH --time=1-00:00:00' in submit_script_text
assert '#SBATCH --nodes=1' in submit_script_text

self.assertTrue("'mpirun' '-np' '23' 'pw.x' '-npool' '1' < 'aiida.in'" in submit_script_text)
assert "'mpirun' '-np' '23' 'pw.x' '-npool' '1' < 'aiida.in'" in submit_script_text

def test_submit_script_bad_shebang(self):
"""Test that first line of submit script is as expected."""
Expand All @@ -225,7 +244,7 @@ def test_submit_script_bad_shebang(self):
submit_script_text = scheduler.get_submit_script(job_tmpl)

# This tests if the implementation correctly chooses the default:
self.assertEqual(submit_script_text.split('\n')[0], expected_first_line)
assert submit_script_text.split('\n')[0] == expected_first_line

def test_submit_script_with_num_cores_per_machine(self): # pylint: disable=invalid-name
"""
Expand All @@ -252,13 +271,13 @@ def test_submit_script_with_num_cores_per_machine(self): # pylint: disable=inva

submit_script_text = scheduler.get_submit_script(job_tmpl)

self.assertTrue('#SBATCH --no-requeue' in submit_script_text)
self.assertTrue('#SBATCH --time=1-00:00:00' in submit_script_text)
self.assertTrue('#SBATCH --nodes=1' in submit_script_text)
self.assertTrue('#SBATCH --ntasks-per-node=2' in submit_script_text)
self.assertTrue('#SBATCH --cpus-per-task=12' in submit_script_text)
assert '#SBATCH --no-requeue' in submit_script_text
assert '#SBATCH --time=1-00:00:00' in submit_script_text
assert '#SBATCH --nodes=1' in submit_script_text
assert '#SBATCH --ntasks-per-node=2' in submit_script_text
assert '#SBATCH --cpus-per-task=12' in submit_script_text

self.assertTrue("'mpirun' '-np' '23' 'pw.x' '-npool' '1' < 'aiida.in'" in submit_script_text)
assert "'mpirun' '-np' '23' 'pw.x' '-npool' '1' < 'aiida.in'" in submit_script_text

def test_submit_script_with_num_cores_per_mpiproc(self): # pylint: disable=invalid-name
"""
Expand All @@ -284,13 +303,13 @@ def test_submit_script_with_num_cores_per_mpiproc(self): # pylint: disable=inva

submit_script_text = scheduler.get_submit_script(job_tmpl)

self.assertTrue('#SBATCH --no-requeue' in submit_script_text)
self.assertTrue('#SBATCH --time=1-00:00:00' in submit_script_text)
self.assertTrue('#SBATCH --nodes=1' in submit_script_text)
self.assertTrue('#SBATCH --ntasks-per-node=1' in submit_script_text)
self.assertTrue('#SBATCH --cpus-per-task=24' in submit_script_text)
assert '#SBATCH --no-requeue' in submit_script_text
assert '#SBATCH --time=1-00:00:00' in submit_script_text
assert '#SBATCH --nodes=1' in submit_script_text
assert '#SBATCH --ntasks-per-node=1' in submit_script_text
assert '#SBATCH --cpus-per-task=24' in submit_script_text

self.assertTrue("'mpirun' '-np' '23' 'pw.x' '-npool' '1' < 'aiida.in'" in submit_script_text)
assert "'mpirun' '-np' '23' 'pw.x' '-npool' '1' < 'aiida.in'" in submit_script_text

def test_submit_script_with_num_cores_per_machine_and_mpiproc1(self): # pylint: disable=invalid-name
"""
Expand Down Expand Up @@ -319,13 +338,13 @@ def test_submit_script_with_num_cores_per_machine_and_mpiproc1(self): # pylint:

submit_script_text = scheduler.get_submit_script(job_tmpl)

self.assertTrue('#SBATCH --no-requeue' in submit_script_text)
self.assertTrue('#SBATCH --time=1-00:00:00' in submit_script_text)
self.assertTrue('#SBATCH --nodes=1' in submit_script_text)
self.assertTrue('#SBATCH --ntasks-per-node=1' in submit_script_text)
self.assertTrue('#SBATCH --cpus-per-task=24' in submit_script_text)
assert '#SBATCH --no-requeue' in submit_script_text
assert '#SBATCH --time=1-00:00:00' in submit_script_text
assert '#SBATCH --nodes=1' in submit_script_text
assert '#SBATCH --ntasks-per-node=1' in submit_script_text
assert '#SBATCH --cpus-per-task=24' in submit_script_text

self.assertTrue("'mpirun' '-np' '23' 'pw.x' '-npool' '1' < 'aiida.in'" in submit_script_text)
assert "'mpirun' '-np' '23' 'pw.x' '-npool' '1' < 'aiida.in'" in submit_script_text

def test_submit_script_with_num_cores_per_machine_and_mpiproc2(self): # pylint: disable=invalid-name
"""
Expand All @@ -340,13 +359,13 @@ def test_submit_script_with_num_cores_per_machine_and_mpiproc2(self): # pylint:
scheduler = SlurmScheduler()

job_tmpl = JobTemplate()
with self.assertRaises(ValueError):
with pytest.raises(ValueError, match='`num_cores_per_machine` must be equal to'):
job_tmpl.job_resource = scheduler.create_job_resource(
num_machines=1, num_mpiprocs_per_machine=1, num_cores_per_machine=24, num_cores_per_mpiproc=23
)


class TestJoblistCommand(unittest.TestCase):
class TestJoblistCommand:
"""
Tests of the issued squeue command.
"""
Expand All @@ -356,15 +375,15 @@ def test_joblist_single(self):
scheduler = SlurmScheduler()

command = scheduler._get_joblist_command(jobs=['123']) # pylint: disable=protected-access
self.assertIn('123,123', command)
assert '123,123' in command

def test_joblist_multi(self):
"""Test that asking for multiple jobs does not result in duplications."""
scheduler = SlurmScheduler()

command = scheduler._get_joblist_command(jobs=['123', '456']) # pylint: disable=protected-access
self.assertIn('123,456', command)
self.assertNotIn('456,456', command)
assert '123,456' in command
assert '456,456' not in command


def test_parse_out_of_memory():
Expand Down

0 comments on commit c42a86b

Please sign in to comment.