Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: verdi data trajectory show #5394

Merged
merged 30 commits into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
00d0b20
fix: verdi data trajectory show
ltalirz Feb 25, 2022
612ad0d
Fix the tests
sphuber Jul 7, 2023
9f548cb
Increase test coverage
sphuber Jul 7, 2023
3888ab8
Test
sphuber Jul 7, 2023
1bcfa4a
skip all
sphuber Jul 7, 2023
1e4971f
Testing
sphuber Jul 7, 2023
b22e995
Another try
sphuber Jul 7, 2023
2e19aaf
Dont block
sphuber Jul 7, 2023
5671e54
gc.collect()
sphuber Jul 7, 2023
532d332
try just jmol
sphuber Jul 8, 2023
8e14baf
Try only mpl_pos
sphuber Jul 8, 2023
097fcd0
Comment out `run_cli_command`
sphuber Jul 8, 2023
72de06b
Comment out data creation
sphuber Jul 8, 2023
3c9e431
COmment out entire test
sphuber Jul 8, 2023
eba1f14
Start building test step by step
sphuber Jul 8, 2023
bf1566b
Add imports
sphuber Jul 8, 2023
c4bbe52
Use skip inside parametrization
sphuber Jul 8, 2023
d2b248e
get rid of autofixture
sphuber Jul 8, 2023
d591bee
comment out matplotlib import and mpl_heatmap parameter
sphuber Jul 8, 2023
31289fb
Reenable matplotlib import
sphuber Jul 8, 2023
d457ce8
Add the tests, but without matplotlib import
sphuber Jul 8, 2023
7ae68c3
Add the mpl_pos param test but without importing/mocking matplotlib
sphuber Jul 8, 2023
cf811fe
Add monkeypatch for `matplotlib.pyplot` requires import again though
sphuber Jul 8, 2023
4bd01c4
Scope monkeypatches
sphuber Jul 8, 2023
ccae92d
Import matplotlib and shit dies
sphuber Jul 8, 2023
b325323
Monkeypatch all the things
sphuber Jul 8, 2023
4bac9f8
Revert most changes and keep essential
sphuber Jul 8, 2023
d79ac8c
Prove matplotlib is the culprit
sphuber Jul 8, 2023
1a47b90
Remove import and add comment to explain the madness
sphuber Jul 8, 2023
79c7f5f
Remove system requirements and parametrize skips
sphuber Jul 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 22 additions & 52 deletions aiida/cmdline/commands/cmd_data/cmd_show.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,96 +12,66 @@
"""
import pathlib

import click

from aiida.cmdline.params import options
from aiida.cmdline.params.options.multivalue import MultipleValueOption
from aiida.cmdline.utils import echo
from aiida.common.exceptions import MultipleObjectsError

SHOW_OPTIONS = [
options.TRAJECTORY_INDEX(),
options.WITH_ELEMENTS(),
click.option('-c', '--contour', type=click.FLOAT, cls=MultipleValueOption, default=None, help='Isovalues to plot'),
click.option(
'--sampling-stepsize',
type=click.INT,
default=None,
help='Sample positions in plot every sampling_stepsize timestep'
),
click.option(
'--stepsize',
type=click.INT,
default=None,
help='The stepsize for the trajectory, set it higher to reduce number of points'
),
click.option('--mintime', type=click.INT, default=None, help='The time to plot from'),
click.option('--maxtime', type=click.INT, default=None, help='The time to plot to'),
click.option('--indices', type=click.INT, cls=MultipleValueOption, default=None, help='Show only these indices'),
click.option(
'--dont-block', 'block', is_flag=True, default=True, help="Don't block interpreter when showing plot."
),
]


def show_options(func):
for option in reversed(SHOW_OPTIONS):
func = option(func)

return func


def _show_jmol(exec_name, trajectory_list, **kwargs):

def has_executable(exec_name):
"""
:return: True if executable can be found in PATH, False otherwise.
"""
import shutil
return shutil.which(exec_name) is not None


def _show_jmol(exec_name, trajectory_list, **_kwargs):
"""
Plugin for jmol
"""
import subprocess
import tempfile

if not has_executable(exec_name):
echo.echo_critical(f"No executable '{exec_name}' found. Add to the path, or try with an absolute path.")

# pylint: disable=protected-access
with tempfile.NamedTemporaryFile(mode='w+b') as handle:
for trajectory in trajectory_list:
handle.write(trajectory._exportcontent('cif', **kwargs)[0])
handle.write(trajectory._exportcontent('cif')[0])
handle.flush()

try:
subprocess.check_output([exec_name, handle.name])
except subprocess.CalledProcessError:
# The program died: just print a message
echo.echo_error(f'the call to {exec_name} ended with an error.')
except OSError as err:
if err.errno == 2:
echo.echo_critical(f"No executable '{exec_name}' found. Add to the path, or try with an absolute path.")
else:
raise


def _show_xcrysden(exec_name, object_list, **kwargs):
def _show_xcrysden(exec_name, trajectory_list, **_kwargs):
"""
Plugin for xcrysden
"""
import subprocess
import tempfile

if len(object_list) > 1:
if len(trajectory_list) > 1:
raise MultipleObjectsError('Visualization of multiple trajectories is not implemented')
obj = object_list[0]
obj = trajectory_list[0]

if not has_executable(exec_name):
echo.echo_critical(f"No executable '{exec_name}' found.")

# pylint: disable=protected-access
with tempfile.NamedTemporaryFile(mode='w+b', suffix='.xsf') as tmpf:
tmpf.write(obj._exportcontent('xsf', **kwargs)[0])

tmpf.write(obj._exportcontent('xsf')[0])
tmpf.flush()

try:
subprocess.check_output([exec_name, '--xsf', tmpf.name])
except subprocess.CalledProcessError:
# The program died: just print a message
echo.echo_error(f'the call to {exec_name} ended with an error.')
except OSError as err:
if err.errno == 2:
echo.echo_critical(f"No executable '{exec_name}' found. Add to the path, or try with an absolute path.")
else:
raise


# pylint: disable=unused-argument
Expand Down
29 changes: 25 additions & 4 deletions aiida/cmdline/commands/cmd_data/cmd_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from aiida.cmdline.commands.cmd_data import cmd_show, verdi_data
from aiida.cmdline.commands.cmd_data.cmd_export import data_export, export_options
from aiida.cmdline.commands.cmd_data.cmd_list import data_list, list_options
from aiida.cmdline.commands.cmd_data.cmd_show import show_options
from aiida.cmdline.params import arguments, options, types
from aiida.cmdline.utils import decorators, echo

Expand Down Expand Up @@ -66,16 +65,38 @@ def trajectory_list(raw, past_days, groups, all_users):
@trajectory.command('show')
@arguments.DATA(type=types.DataParamType(sub_classes=('aiida.data:core.array.trajectory',)))
@options.VISUALIZATION_FORMAT(type=click.Choice(VISUALIZATION_FORMATS), default='jmol')
@show_options
@options.TRAJECTORY_INDEX()
@options.WITH_ELEMENTS()
@click.option(
'-c', '--contour', type=click.FLOAT, cls=options.MultipleValueOption, default=None, help='Isovalues to plot'
)
@click.option(
'--sampling-stepsize',
type=click.INT,
default=None,
help='Sample positions in plot every sampling_stepsize timestep'
)
@click.option(
'--stepsize',
type=click.INT,
default=None,
help='The stepsize for the trajectory, set it higher to reduce number of points'
)
@click.option('--mintime', type=click.INT, default=None, help='The time to plot from')
@click.option('--maxtime', type=click.INT, default=None, help='The time to plot to')
@click.option(
'--indices', type=click.INT, cls=options.MultipleValueOption, default=None, help='Show only these indices'
)
@click.option('--dont-block', 'block', is_flag=True, default=True, help="Don't block interpreter when showing plot.")
@decorators.with_dbenv()
def trajectory_show(data, fmt):
def trajectory_show(data, fmt, **kwargs):
"""Visualize a trajectory."""
try:
show_function = getattr(cmd_show, f'_show_{fmt}')
except AttributeError:
echo.echo_critical(f'visualization format {fmt} is not supported')

show_function(fmt, data)
show_function(exec_name=fmt, trajectory_list=data, **kwargs)


@trajectory.command('export')
Expand Down
19 changes: 7 additions & 12 deletions aiida/orm/nodes/data/array/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,8 +652,6 @@ def show_mpl_pos(self, **kwargs): # pylint: disable=too-many-locals
from ase.data.colors import cpk_colors as colors
else:
raise ValueError(f'Unknown color spec {colors}')
if kwargs:
raise ValueError(f'Unrecognized keyword {kwargs.keys()}')

if element_list is None:
# If not all elements are allowed
Expand Down Expand Up @@ -703,12 +701,8 @@ def show_mpl_heatmap(self, **kwargs): # pylint: disable=invalid-name,too-many-a
from mayavi import mlab
except ImportError:
raise ImportError(
'Unable to import the mayavi package, that is required to'
'use the plotting feature you requested. '
'Please install it first and then call this command again '
'(note that the installation of mayavi is quite complicated '
'and requires that you already installed the python numpy '
'package, as well as the vtk package'
'The plotting feature you requested requires the mayavi package.'
'Try `pip install mayavi` or consult the documentation.'
)
from ase.data import atomic_numbers
from ase.data.colors import jmol_colors
Expand Down Expand Up @@ -847,7 +841,7 @@ def plot_positions_XYZ( # pylint: disable=too-many-arguments,too-many-locals,in
dont_block=False,
mintime=None,
maxtime=None,
label_sparsity=10):
n_labels=10):
"""
Plot with matplotlib the positions of the coordinates of the atoms
over time for a trajectory
Expand All @@ -862,14 +856,14 @@ def plot_positions_XYZ( # pylint: disable=too-many-arguments,too-many-locals,in
:param dont_block: passed to plt.show() as ``block=not dont_block``
:param mintime: if specified, cut the time axis at the specified min value
:param maxtime: if specified, cut the time axis at the specified max value
:param label_sparsity: how often to put a label with the pair (t, coord)
:param n_labels: how many labels (t, coord) to put
"""
from matplotlib import pyplot as plt
from matplotlib.gridspec import GridSpec
import numpy as np

tlim = [times[0], times[-1]]
index_range = [0, len(times)]
index_range = [0, len(times) - 1]
if mintime is not None:
tlim[0] = mintime
index_range[0] = np.argmax(times > mintime)
Expand All @@ -896,7 +890,8 @@ def plot_positions_XYZ( # pylint: disable=too-many-arguments,too-many-locals,in
plt.ylabel(r'Z Position $\left[{}\right]$'.format(positions_unit))
plt.xlabel(f'Time [{times_unit}]')
plt.xlim(*tlim)
sparse_indices = np.linspace(*index_range, num=label_sparsity, dtype=int)
n_labels = np.minimum(n_labels, len(times)) # don't need more labels than times
sparse_indices = np.linspace(*index_range, num=n_labels, dtype=int)

for index, traj in enumerate(trajectories):
if index not in indices_to_show:
Expand Down
52 changes: 52 additions & 0 deletions tests/cmdline/commands/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
cmd_cif,
cmd_dict,
cmd_remote,
cmd_show,
cmd_singlefile,
cmd_structure,
cmd_trajectory,
Expand All @@ -37,6 +38,15 @@
from tests.static import STATIC_DIR


def has_mayavi() -> bool:
"""Return whether the ``mayavi`` module can be imported."""
try:
import mayavi # pylint: disable=unused-import
except ImportError:
return False
return True


class DummyVerdiDataExportable:
"""Test exportable data objects."""

Expand Down Expand Up @@ -517,6 +527,48 @@ def test_export(self, output_flag, tmp_path):
new_supported_formats = list(cmd_trajectory.EXPORT_FORMATS)
self.data_export_test(TrajectoryData, self.pks, new_supported_formats, output_flag, tmp_path)

@pytest.mark.parametrize(
'fmt', (
pytest.param(
'jmol', marks=pytest.mark.skipif(not cmd_show.has_executable('jmol'), reason='No jmol executable.')
),
pytest.param(
'xcrysden',
marks=pytest.mark.skipif(not cmd_show.has_executable('xcrysden'), reason='No xcrysden executable.')
),
pytest.param(
'mpl_heatmap', marks=pytest.mark.skipif(not has_mayavi(), reason='Package `mayavi` not installed.')
), pytest.param('mpl_pos')
)
)
def test_trajectoryshow(self, fmt, monkeypatch, run_cli_command):
"""Test showing the trajectory data in different formats"""
trajectory_pk = self.pks[DummyVerdiDataListable.NODE_ID_STR]
options = ['--format', fmt, str(trajectory_pk), '--dont-block']

def mock_check_output(options):
assert isinstance(options, list)
assert options[0] == fmt

if fmt in ['jmol', 'xcrysden']:
# This is called by the ``_show_jmol`` and ``_show_xcrysden`` implementations. We want to test just the
# function but not the actual commands through a sub process. Note that this mock needs to happen only for
# these specific formats, because ``matplotlib`` used in the others _also_ calls ``subprocess.check_output``
monkeypatch.setattr(sp, 'check_output', mock_check_output)

if fmt in ['mpl_pos']:
# This has to be mocked because ``plot_positions_xyz`` imports ``matplotlib.pyplot`` and for some completely
# unknown reason, causes ``tests/storage/psql_dos/test_backend.py::test_unload_profile`` to fail. For some
# reason, merely importing ``matplotlib`` (even here directly in the test) will cause that test to claim
# that there still is something holding on to a reference of an sqlalchemy session that it keeps track of
# in the ``sqlalchemy.orm.session._sessions`` weak ref dictionary. Since it is impossible to figure out why
# the hell importing matplotlib would interact with sqlalchemy sessions, the function that does the import
# is simply mocked out for now.
from aiida.orm.nodes.data.array import trajectory
monkeypatch.setattr(trajectory, 'plot_positions_XYZ', lambda *args, **kwargs: None)

run_cli_command(cmd_trajectory.trajectory_show, options, use_subprocess=False)


class TestVerdiDataStructure(DummyVerdiDataListable, DummyVerdiDataExportable):
"""Test verdi data core.structure."""
Expand Down
Loading