Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use pymatgen StructureMatcher as initial check and preprocessor before detailed matching #85

Merged
merged 8 commits into from
Jun 11, 2022
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ exclude: '^docs/conf.py'

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.2.0
rev: v4.3.0
hooks:
- id: trailing-whitespace
- id: check-added-large-files
Expand Down
21 changes: 20 additions & 1 deletion src/xtal2png/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pymatgen.core.lattice import Lattice
from pymatgen.core.structure import Structure
from pymatgen.io.cif import CifWriter
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

from xtal2png import __version__
from xtal2png.utils.data import dummy_structures, rgb_scaler, rgb_unscaler
Expand Down Expand Up @@ -104,6 +105,14 @@ class XtalConverter:
save_dir : Union[str, 'PathLike[str]']
Directory to save PNG files via ``func:xtal2png``,
by default path.join("data", "interim")
symprec : float, optional
The symmetry precision to use when decoding `pymatgen` structures via
``func:pymatgen.symmetry.analyzer.SpacegroupAnalyzer.get_refined_structure``. By
default 0.1.
angle_tolerance : Union[float, int], optional
The angle tolerance (degrees) to use when decoding `pymatgen` structures via
``func:pymatgen.symmetry.analyzer.SpacegroupAnalyzer.get_refined_structure``. By
default 5.0.
Examples
--------
Expand All @@ -125,6 +134,8 @@ def __init__(
distance_range: Tuple[float, float] = (0.0, 18.0),
max_sites: int = 52,
save_dir: Union[str, "PathLike[str]"] = path.join("data", "preprocessed"),
symprec: float = 0.1,
angle_tolerance: float = 5.0,
):
"""Instantiate an XtalConverter object with desired ranges and ``max_sites``."""
self.atom_range = atom_range
Expand All @@ -138,6 +149,8 @@ def __init__(
self.distance_range = distance_range
self.max_sites = max_sites
self.save_dir = save_dir
self.symprec = symprec
self.angle_tolerance = angle_tolerance

Path(save_dir).mkdir(exist_ok=True, parents=True)

Expand Down Expand Up @@ -283,7 +296,9 @@ def png2xtal(
if save:
for s in S:
fpath = path.join(self.save_dir, construct_save_name(s) + ".cif")
CifWriter(s).write_file(fpath)
CifWriter(
s, symprec=self.symprec, angle_tolerance=self.angle_tolerance
).write_file(fpath)

return S

Expand Down Expand Up @@ -625,6 +640,10 @@ def arrays_to_structures(self, data: np.ndarray):
a=a, b=b, c=c, alpha=alpha, beta=beta, gamma=gamma
)
structure = Structure(lattice, at, fr)
spa = SpacegroupAnalyzer(
structure, symprec=self.symprec, angle_tolerance=self.angle_tolerance
)
structure = spa.get_refined_structure()
S.append(structure)

return S
Expand Down
44 changes: 36 additions & 8 deletions tests/xtal2png_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
# (test), and back to crystal Structure (test)


from warnings import warn

import plotly.express as px
from numpy.testing import assert_allclose
from numpy.testing import assert_allclose, assert_equal
from pymatgen.analysis.structure_matcher import ElementComparator, StructureMatcher

from xtal2png.core import XtalConverter
from xtal2png.utils.data import (
Expand All @@ -20,20 +23,39 @@


def assert_structures_approximate_match(example_structures, structures):
for s, structure in zip(example_structures, structures):
for i, (s, structure) in enumerate(zip(example_structures, structures)):
# d = np.linalg.norm(s._lattice.abc)
# sm = StructureMatcher(
# ltol=rgb_loose_tol * d,
# stol=rgb_loose_tol * d,
# angle_tol=rgb_loose_tol * 180,
# comparator=ElementComparator(),
# )
sm = StructureMatcher(comparator=ElementComparator())
is_match = sm.fit(s, structure)
if not is_match:
warn(
f"{i}-th original and decoded structures do not match according to StructureMatcher(comparator=ElementComparator()).fit(s, structure).\n\nOriginal (s): {s}\n\nDecoded (structure): {structure}" # noqa: E501
)

sm = StructureMatcher(primitive_cell=False, comparator=ElementComparator())
s2 = sm.get_s2_like_s1(s, structure)

a_check = s._lattice.a
b_check = s._lattice.b
c_check = s._lattice.c
angles_check = s._lattice.angles
atomic_numbers_check = s.atomic_numbers
frac_coords_check = s.frac_coords
space_group_check = s.get_space_group_info()[1]

latt_a = structure._lattice.a
latt_b = structure._lattice.b
latt_c = structure._lattice.c
angles = structure._lattice.angles
atomic_numbers = structure.atomic_numbers
frac_coords = structure.frac_coords
latt_a = s2._lattice.a
latt_b = s2._lattice.b
latt_c = s2._lattice.c
angles = s2._lattice.angles
atomic_numbers = s2.atomic_numbers
frac_coords = s2.frac_coords
space_group = s.get_space_group_info()[1]

assert_allclose(
a_check,
Expand Down Expand Up @@ -78,6 +100,12 @@ def assert_structures_approximate_match(example_structures, structures):
err_msg="atomic numbers not all close",
)

assert_equal(
space_group_check,
space_group,
err_msg=f"space groups do not match. Original: {space_group_check}. Decoded: {space_group}.", # noqa: E501
)


def test_structures_to_arrays():
xc = XtalConverter()
Expand Down