From 3a11c9c750817b10500bc7376df7369536a9b090 Mon Sep 17 00:00:00 2001 From: richardjgowers Date: Sun, 9 Feb 2020 17:03:06 +0000 Subject: [PATCH] made parmed checking not require parmed import --- package/MDAnalysis/core/_get_readers.py | 21 +++++++++++++---- .../coordinates/test_parmed.py | 23 ++++++++++++++----- .../MDAnalysisTests/topology/test_parmed.py | 5 ++-- 3 files changed, 36 insertions(+), 13 deletions(-) diff --git a/package/MDAnalysis/core/_get_readers.py b/package/MDAnalysis/core/_get_readers.py index 0d21735f485..79e65a4c26d 100644 --- a/package/MDAnalysis/core/_get_readers.py +++ b/package/MDAnalysis/core/_get_readers.py @@ -32,6 +32,15 @@ from ..lib import util +def _is_parmed_object(thing): + module = inspect.getmodule(thing.__class__) + + if module is None: + return False + else: + return module.__name__.startswith('parmed') + + def get_reader_for(filename, format=None): """Return the appropriate trajectory reader class for `filename`. @@ -66,7 +75,7 @@ def get_reader_for(filename, format=None): :class:`~MDAnalysis.coordinates.memory.MemoryReader` is returned. - If `filename` is an MMTF object, :class:`~MDAnalysis.coordinates.MMTF.MMTFReader` is returned. - - If `filename` is a ParmEd Structure, + - If `filename` is a ParmEd Structure, :class:`~MDAnalysis.coordinates.ParmEd.ParmEdReader` is returned. - If `filename` is an iterable of filenames, :class:`~MDAnalysis.coordinates.chain.ChainReader` is returned. @@ -93,6 +102,8 @@ def get_reader_for(filename, format=None): elif isinstance(filename, mmtf.MMTFDecoder): # mmtf slurps mmtf object format = 'MMTF' + elif _is_parmed_object(filename): + format = 'PARMED' else: # else let the guessing begin! format = util.guess_format(filename) @@ -178,7 +189,7 @@ def get_writer_for(filename, format=None, multiframe=None): None) else: format = util.check_compressed_format(root, ext) - + if format == '': raise ValueError(( 'File format could not be guessed from {}, ' @@ -236,6 +247,8 @@ def get_parser_for(filename, format=None): if format is None: if isinstance(filename, mmtf.MMTFDecoder): format = 'mmtf' + elif _is_parmed_object(filename): + format = 'PARMED' else: format = util.guess_format(filename) format = format.upper() @@ -273,7 +286,7 @@ def get_converter_for(format): TypeError If no appropriate parser could be found. - + .. versionadded:: 0.21.0 """ try: @@ -281,4 +294,4 @@ def get_converter_for(format): except KeyError: errmsg = 'No converter found for {} format' raise_from(TypeError(errmsg.format(format)), None) - return writer \ No newline at end of file + return writer diff --git a/testsuite/MDAnalysisTests/coordinates/test_parmed.py b/testsuite/MDAnalysisTests/coordinates/test_parmed.py index 31a2d3871bc..319f43e7a87 100644 --- a/testsuite/MDAnalysisTests/coordinates/test_parmed.py +++ b/testsuite/MDAnalysisTests/coordinates/test_parmed.py @@ -46,10 +46,23 @@ pmd = pytest.importorskip('parmed') +from MDAnalysis.core._get_readers import _is_parmed_object + + +@pytest.mark.parametrize('thing,reference', [ + ('foo', False), + ([1,2,3], False), + (pmd.load_file(GRO), True), +]) +def test_is_parmed_object(thing, reference): + assert _is_parmed_object(thing) == reference + + + class TestParmEdReaderGRO: ref_filename = GRO - universe = mda.Universe(pmd.load_file(GRO), format='parmed') + universe = mda.Universe(pmd.load_file(GRO)) ref = mda.Universe(GRO) prec = 3 @@ -68,10 +81,8 @@ def test_coordinates(self): class BaseTestParmEdReader(_SingleFrameReader): - def setUp(self): - self.universe = mda.Universe(pmd.load_file(self.ref_filename), - format='parmed') + self.universe = mda.Universe(pmd.load_file(self.ref_filename)) self.ref = mda.Universe(self.ref_filename) self.prec = 3 @@ -135,7 +146,7 @@ def output(self, universe): @pytest.fixture(scope='class') def roundtrip(self, ref): - u = mda.Universe(ref, format='parmed') + u = mda.Universe(ref) return u.atoms.convert_to('PARMED') def test_equivalent_connectivity_counts(self, universe, output): @@ -232,7 +243,7 @@ class BaseTestParmEdConverterFromParmed(BaseTestParmEdConverter): @pytest.fixture(scope='class') def universe(self, ref): - return mda.Universe(ref, format='parmed') + return mda.Universe(ref) def test_equivalent_connectivity_counts(self, ref, output): for attr in ('atoms', 'bonds', 'angles', 'dihedrals', 'impropers', diff --git a/testsuite/MDAnalysisTests/topology/test_parmed.py b/testsuite/MDAnalysisTests/topology/test_parmed.py index 2886b4c108b..07c93436387 100644 --- a/testsuite/MDAnalysisTests/topology/test_parmed.py +++ b/testsuite/MDAnalysisTests/topology/test_parmed.py @@ -61,10 +61,10 @@ def filename(self): @pytest.fixture def universe(self, filename): - return mda.Universe(filename, format='parmed') + return mda.Universe(filename) def test_creates_universe(self, filename): - u = mda.Universe(filename, format='parmed') + u = mda.Universe(filename) assert isinstance(u, mda.Universe) def test_bonds_total_counts(self, top, filename): @@ -241,4 +241,3 @@ def test_dihedral_types(self, universe): )): assert dih.type[i].type.phi_k == phi_k assert dih.type[i].type.per == per - \ No newline at end of file