diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 419eca0..38076ba 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ exclude: '^docs/conf.py' repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.2.0 + rev: v4.3.0 hooks: - id: trailing-whitespace - id: check-added-large-files diff --git a/src/xtal2png/core.py b/src/xtal2png/core.py index 6a0eb07..13a5ebc 100644 --- a/src/xtal2png/core.py +++ b/src/xtal2png/core.py @@ -17,6 +17,7 @@ from pymatgen.core.lattice import Lattice from pymatgen.core.structure import Structure from pymatgen.io.cif import CifWriter +from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from xtal2png import __version__ from xtal2png.utils.data import dummy_structures, rgb_scaler, rgb_unscaler @@ -104,6 +105,14 @@ class XtalConverter: save_dir : Union[str, 'PathLike[str]'] Directory to save PNG files via ``func:xtal2png``, by default path.join("data", "interim") + symprec : float, optional + The symmetry precision to use when decoding `pymatgen` structures via + ``func:pymatgen.symmetry.analyzer.SpaceGroupAnalyzer.get_refined_structure``. By + default 0.1. + angle_tolerance : Union[float, int], optional + The angle tolerance (degrees) to use when decoding `pymatgen` structures via + ``func:pymatgen.symmetry.analyzer.SpaceGroupAnalyzer.get_refined_structure``. By + default 5.0. Examples -------- @@ -125,6 +134,8 @@ def __init__( distance_range: Tuple[float, float] = (0.0, 18.0), max_sites: int = 52, save_dir: Union[str, "PathLike[str]"] = path.join("data", "preprocessed"), + symprec: float = 0.1, + angle_tolerance: float = 5.0, ): """Instantiate an XtalConverter object with desired ranges and ``max_sites``.""" self.atom_range = atom_range @@ -138,6 +149,8 @@ def __init__( self.distance_range = distance_range self.max_sites = max_sites self.save_dir = save_dir + self.symprec = symprec + self.angle_tolerance = angle_tolerance Path(save_dir).mkdir(exist_ok=True, parents=True) @@ -283,7 +296,9 @@ def png2xtal( if save: for s in S: fpath = path.join(self.save_dir, construct_save_name(s) + ".cif") - CifWriter(s).write_file(fpath) + CifWriter( + s, symprec=self.symprec, angle_tolerance=self.angle_tolerance + ).write_file(fpath) return S @@ -625,6 +640,10 @@ def arrays_to_structures(self, data: np.ndarray): a=a, b=b, c=c, alpha=alpha, beta=beta, gamma=gamma ) structure = Structure(lattice, at, fr) + spa = SpacegroupAnalyzer( + structure, symprec=self.symprec, angle_tolerance=self.angle_tolerance + ) + structure = spa.get_refined_structure() S.append(structure) return S diff --git a/tests/xtal2png_test.py b/tests/xtal2png_test.py index 3187d3b..d6b6e36 100644 --- a/tests/xtal2png_test.py +++ b/tests/xtal2png_test.py @@ -2,8 +2,11 @@ # (test), and back to crystal Structure (test) +from warnings import warn + import plotly.express as px -from numpy.testing import assert_allclose +from numpy.testing import assert_allclose, assert_equal +from pymatgen.analysis.structure_matcher import ElementComparator, StructureMatcher from xtal2png.core import XtalConverter from xtal2png.utils.data import ( @@ -20,20 +23,39 @@ def assert_structures_approximate_match(example_structures, structures): - for s, structure in zip(example_structures, structures): + for i, (s, structure) in enumerate(zip(example_structures, structures)): + # d = np.linalg.norm(s._lattice.abc) + # sm = StructureMatcher( + # ltol=rgb_loose_tol * d, + # stol=rgb_loose_tol * d, + # angle_tol=rgb_loose_tol * 180, + # comparator=ElementComparator(), + # ) + sm = StructureMatcher(comparator=ElementComparator()) + is_match = sm.fit(s, structure) + if not is_match: + warn( + f"{i}-th original and decoded structures do not match according to StructureMatcher(comparator=ElementComparator()).fit(s, structure).\n\nOriginal (s): {s}\n\nDecoded (structure): {structure}" # noqa: E501 + ) + + sm = StructureMatcher(primitive_cell=False, comparator=ElementComparator()) + s2 = sm.get_s2_like_s1(s, structure) + a_check = s._lattice.a b_check = s._lattice.b c_check = s._lattice.c angles_check = s._lattice.angles atomic_numbers_check = s.atomic_numbers frac_coords_check = s.frac_coords + space_group_check = s.get_space_group_info()[1] - latt_a = structure._lattice.a - latt_b = structure._lattice.b - latt_c = structure._lattice.c - angles = structure._lattice.angles - atomic_numbers = structure.atomic_numbers - frac_coords = structure.frac_coords + latt_a = s2._lattice.a + latt_b = s2._lattice.b + latt_c = s2._lattice.c + angles = s2._lattice.angles + atomic_numbers = s2.atomic_numbers + frac_coords = s2.frac_coords + space_group = s.get_space_group_info()[1] assert_allclose( a_check, @@ -78,6 +100,12 @@ def assert_structures_approximate_match(example_structures, structures): err_msg="atomic numbers not all close", ) + assert_equal( + space_group_check, + space_group, + err_msg=f"space groups do not match. Original: {space_group_check}. Decoded: {space_group}.", # noqa: E501 + ) + def test_structures_to_arrays(): xc = XtalConverter()