save and load, equations, and plotting #69

Merged
merged 6 commits into from
May 26, 2023
1,152 changes: 1,152 additions & 0 deletions reports/equations.nb

Large diffs are not rendered by default.

13 changes: 13 additions & 0 deletions reports/equations/coverage.tex
@@ -0,0 +1,13 @@
\begin{equation} \label{eq:coverage}
\frac{\sum _{i=1}^{n_{\text{test}}} \sum _{j=1}^{n_{\text{gen}}} \left(
\begin{array}{cc}
\{ &
\begin{array}{cc}
1 & d\left(s_{\text{test},i},s_{\text{gen},j}\right)\leq \text{tol} \\
0 & d\left(s_{\text{test},i},s_{\text{gen},j}\right)>\text{tol} \\
\end{array}
\\
\end{array}
\right)}{n_{\text{test}}}
\end{equation}
where $n_{\text{test}}$, $n_{\text{gen}}$, $d$, $s_{\text{test},i}$, $s_{\text{gen},j}$, and $\text{tol}$ represent number of structures in the test set, number of structures in the generated set, crystallographic distance according to \texttt{pymatgen.analysis.structure\_matcher.StructureMatcher}, $i$-th structure of the test set, $j$-th structure of the generated set, and a tolerance threshold, respectively.
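As a sanity check on the double sum, here is a minimal sketch of the coverage formula in plain Python. The names `coverage` and `abs_diff` are hypothetical, and floats stand in for structures with an absolute-difference distance. With real structures, the match test would come from pymatgen's `StructureMatcher` (for example its `fit` method) rather than from `abs_diff`.

```python
def coverage(test, gen, d, tol):
    """Matching (test, generated) pairs divided by the number of test items."""
    matches = sum(1 for s_test in test for s_gen in gen if d(s_test, s_gen) <= tol)
    return matches / len(test)


def abs_diff(a, b):
    # Toy distance: an assumption standing in for a crystallographic distance
    return abs(a - b)


test_set = [0.0, 1.0, 2.0, 3.0]
gen_set = [0.05, 2.0, 9.0]
score = coverage(test_set, gen_set, abs_diff, tol=0.1)
print(score)
```

Because every matching pair is counted, the expression as written can exceed 1 when several generated structures match the same test structure.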
13 changes: 13 additions & 0 deletions reports/equations/novelty.tex
@@ -0,0 +1,13 @@
\begin{equation} \label{eq:novelty}
1-\frac{\sum _{i=1}^{n_{\text{train}}} \sum _{j=1}^{n_{\text{gen}}} \left(
\begin{array}{cc}
\{ &
\begin{array}{cc}
1 & d\left(s_{\text{train},i},s_{\text{gen},j}\right)\leq \text{tol} \\
0 & d\left(s_{\text{train},i},s_{\text{gen},j}\right)>\text{tol} \\
\end{array}
\\
\end{array}
\right)}{n_{\text{gen}}}
\end{equation}
where $n_{\text{train}}$, $n_{\text{gen}}$, $d$, $s_{\text{train},i}$, $s_{\text{gen},j}$, and $\text{tol}$ represent number of structures in the training set, number of structures in the generated set, crystallographic distance according to \texttt{StructureMatcher} from \texttt{pymatgen.analysis.structure\_matcher}, $i$-th structure of the training set, $j$-th structure of the generated set, and a tolerance threshold, respectively.
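The novelty expression can be checked the same way. This is a sketch under the same toy assumptions (hypothetical function names, floats in place of structures, absolute difference in place of the `StructureMatcher` distance), not the library's implementation.

```python
def novelty(train, gen, d, tol):
    """One minus the (train, generated) match count over the number of generated items."""
    matches = sum(1 for s_train in train for s_gen in gen if d(s_train, s_gen) <= tol)
    return 1 - matches / len(gen)


def abs_diff(a, b):
    # Toy distance: an assumption standing in for a crystallographic distance
    return abs(a - b)


train_set = [0.0, 1.0]
gen_set = [0.0, 1.02, 4.0, 7.0]  # two of the four items reproduce training items
score = novelty(train_set, gen_set, abs_diff, tol=0.1)
copied = novelty(train_set, train_set, abs_diff, tol=0.1)  # generated set == training set
print(score, copied)
```

Copying the training set verbatim drives the score to zero, which is the behavior the metric is meant to penalize.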
14 changes: 14 additions & 0 deletions reports/equations/uniqueness.tex
@@ -0,0 +1,14 @@
\begin{equation} \label{eq:uniqueness}
1-\frac{\sum _{i=1}^{n_{\text{gen}}} \sum _{j=1}^{n_{\text{gen}}} \left(
\begin{array}{cc}
\{ &
\begin{array}{cc}
0 & i=j \\
1 & d\left(s_{\text{gen},i},s_{\text{gen},j}\right)\leq \text{tol}\land i\neq j \\
0 & d\left(s_{\text{gen},i},s_{\text{gen},j}\right)>\text{tol}\land i\neq j \\
\end{array}
\\
\end{array}
\right)}{n_{\text{gen}}^2-n_{\text{gen}}}
\end{equation}
where $n_{\text{gen}}$, $d$, $s_{\text{gen},i}$, $s_{\text{gen},j}$, and $\text{tol}$ represent number of structures in the generated set, crystallographic distance according to \texttt{StructureMatcher} from \texttt{pymatgen.analysis.structure\_matcher}, $i$-th structure of the generated set, $j$-th structure of the generated set, and a tolerance threshold, respectively.
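For uniqueness, the only new ingredients are the exclusion of self-comparisons and the denominator $n_{\text{gen}}^2-n_{\text{gen}}$. The sketch below uses hypothetical names and a toy distance on floats. The guard for fewer than two items is an addition of this sketch, since the denominator is zero in that case.

```python
def uniqueness(gen, d, tol):
    """One minus the fraction of ordered pairs (i != j) that match within tol."""
    n = len(gen)
    if n < 2:
        raise ValueError("uniqueness needs at least two generated structures")
    matches = sum(
        1
        for i in range(n)
        for j in range(n)
        if i != j and d(gen[i], gen[j]) <= tol
    )
    return 1 - matches / (n**2 - n)


def abs_diff(a, b):
    # Toy distance: an assumption standing in for a crystallographic distance
    return abs(a - b)


gen_set = [0.0, 0.0, 1.0, 5.0]  # one duplicated item
score = uniqueness(gen_set, abs_diff, tol=0.1)
print(score)
```

Each duplicate pair is counted twice, once as (i, j) and once as (j, i), which matches the ordered-pair denominator.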
4 changes: 4 additions & 0 deletions reports/equations/validity.tex
@@ -0,0 +1,4 @@
\begin{equation} \label{eq:validity}
1-\frac{w\left(\text{SG}_{\text{train}},\text{SG}_{\text{gen}}\right)}{w\left(\text{SG}_{\text{train}},1\right)}
\end{equation}
where $\mathit{w}$, $\text{SG}_{\text{train}}$, and $\text{SG}_{\text{gen}}$ represent Wasserstein distance, vector of space group numbers for the training data, and vector of space group numbers for the generated structures, respectively.
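To make the scaling concrete, here is a dependency-free sketch of the validity expression. `wasserstein_1d` is a hand-written 1-D Wasserstein distance (the integral of the absolute difference between the two empirical CDFs). In practice one would more likely call `scipy.stats.wasserstein_distance`, and whether the library does so is an assumption. The function and argument names are hypothetical. `sg_compare` is whichever vector is compared against the training distribution, and the dummy case is taken to be the single value 1.

```python
from bisect import bisect_right


def wasserstein_1d(u, v):
    """First Wasserstein distance between two 1-D samples."""
    u_sorted, v_sorted = sorted(u), sorted(v)
    points = sorted(set(u_sorted) | set(v_sorted))
    total = 0.0
    for left, right in zip(points, points[1:]):
        # Empirical CDFs are constant on [left, right)
        cdf_u = bisect_right(u_sorted, left) / len(u_sorted)
        cdf_v = bisect_right(v_sorted, left) / len(v_sorted)
        total += abs(cdf_u - cdf_v) * (right - left)
    return total


def validity(sg_train, sg_compare):
    """One minus the distance to sg_compare, scaled by the dummy-case distance."""
    dummy = wasserstein_1d(sg_train, [1])
    if dummy == 0:
        raise ValueError("dummy-case distance is zero (all training space groups are 1)")
    return 1 - wasserstein_1d(sg_train, sg_compare) / dummy


sg_train = [1, 2, 225]
same = validity(sg_train, sg_train)  # identical distributions
worst = validity(sg_train, [1, 1, 1])  # everything in space group 1
print(same, worst)
```

The score is 1 for identical distributions and 0 for the dummy case. It can become negative if the compared distribution is farther from the training data than the dummy case is.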
69 changes: 69 additions & 0 deletions reports/matbench-genmetrics.tex
@@ -0,0 +1,69 @@
\subsection{VALIDATION: Assessing Performance}
\label{sec:validation}
\textcolor{red}{****STERLING****}

While significant progress has been made toward standardized, easy-to-use benchmarks for molecular discovery \cite{brownGuacaMolBenchmarkingModels2019}, benchmarking remains a challenge for solid-state materials \cite{spekCheckCIFValidationALERTS2020, xie_crystal_2022, zhao_physics_2022}. To address this limitation, we propose \texttt{matbench-genmetrics}, an open-source Python library for benchmarking generative models for crystal structures. It incorporates benchmark datasets, splits, and metrics inspired by the Crystal Diffusion Variational AutoEncoder (CDVAE) \cite{xieCrystalDiffusionVariational2021}. We provide our own benchmarks using time-series-style cross-validation splits of Materials Project data via our \texttt{mp-time-split} package. We also plan to incorporate an automated leaderboard and submission system, along with an easy-to-use example showing how to prepare and submit benchmarks for new models.

Here we define the four metrics used in \texttt{matbench-genmetrics}: validity, coverage, novelty, and uniqueness (\cref{fig:matbench-genmetrics}).

\begin{figure}
\centering
\includegraphics[width=0.48\textwidth]{sections/figs/metrics.png}
\caption{The four metrics of \texttt{matbench-genmetrics} for assessing performance of materials generative models are validity, coverage, novelty, and uniqueness. Validity is the comparison of distribution characteristics (space group number) between the generated materials and the training and test sets. Coverage is the number of matches between the generated structures and a held-out test set. Novelty is a comparison between the generated and training structures. Finally, uniqueness is a measure of the number of repeats within the generated structures (i.e., comparing the set of generated structures to itself).}
\label{fig:matbench-genmetrics}
\end{figure}

We define validity as one minus the Wasserstein distance between the distributions of space group numbers for the training and generated structures, divided by the distance for the dummy case, in which the training distribution is compared against space group number 1:

\begin{equation} \label{eq:validity}
1-\frac{w\left(\mathrm{SG}_{\mathrm{train}},\mathrm{SG}_{\mathrm{gen}}\right)}{w\left(\mathrm{SG}_{\mathrm{train}},1\right)}
\end{equation}
where $w$, $\mathrm{SG}_{\mathrm{train}}$, and $\mathrm{SG}_{\mathrm{gen}}$ represent the Wasserstein distance, the vector of space group numbers for the training data, and the vector of space group numbers for the generated structures, respectively.

Coverage (``predict the future'') is given by the number of matches between the held-out test structures and the generated structures, divided by the number of test structures:

\begin{equation} \label{eq:coverage}
\frac{\sum _{i=1}^{n_{\text{test}}} \sum _{j=1}^{n_{\text{gen}}} \left(
\left\{
\begin{array}{ll}
1 & d\left(s_{\text{test},i},s_{\text{gen},j}\right)\leq \text{tol} \\
0 & d\left(s_{\text{test},i},s_{\text{gen},j}\right)>\text{tol} \\
\end{array}
\right.
\right)}{n_{\text{test}}}
\end{equation}
where $n_{\text{test}}$, $n_{\text{gen}}$, $d$, $s_{\text{test},i}$, $s_{\text{gen},j}$, and $\text{tol}$ represent the number of structures in the test set, the number of structures in the generated set, the crystallographic distance according to \texttt{StructureMatcher} from \texttt{pymatgen.analysis.structure\_matcher}, the $i$-th structure of the test set, the $j$-th structure of the generated set, and a tolerance threshold, respectively.

Novelty is given by one minus the number of matches between the training structures and the generated structures, divided by the number of generated structures:

\begin{equation} \label{eq:novelty}
1-\frac{\sum _{i=1}^{n_{\text{train}}} \sum _{j=1}^{n_{\text{gen}}} \left(
\left\{
\begin{array}{ll}
1 & d\left(s_{\text{train},i},s_{\text{gen},j}\right)\leq \text{tol} \\
0 & d\left(s_{\text{train},i},s_{\text{gen},j}\right)>\text{tol} \\
\end{array}
\right.
\right)}{n_{\text{gen}}}
\end{equation}
where $n_{\text{train}}$, $n_{\text{gen}}$, $d$, $s_{\text{train},i}$, $s_{\text{gen},j}$, and $\text{tol}$ represent the number of structures in the training set, the number of structures in the generated set, the crystallographic distance according to \texttt{StructureMatcher} from \texttt{pymatgen.analysis.structure\_matcher}, the $i$-th structure of the training set, the $j$-th structure of the generated set, and a tolerance threshold, respectively.

Uniqueness is given by one minus the number of matches among the generated structures, excluding self-comparisons, divided by the total number of possible non-self comparisons:

\begin{equation} \label{eq:uniqueness}
1-\frac{\sum _{i=1}^{n_{\text{gen}}} \sum _{j=1}^{n_{\text{gen}}} \left(
\left\{
\begin{array}{ll}
0 & i=j \\
1 & d\left(s_{\text{gen},i},s_{\text{gen},j}\right)\leq \text{tol}\land i\neq j \\
0 & d\left(s_{\text{gen},i},s_{\text{gen},j}\right)>\text{tol}\land i\neq j \\
\end{array}
\right.
\right)}{n_{\text{gen}}^2-n_{\text{gen}}}
\end{equation}
where $n_{\text{gen}}$, $d$, $s_{\text{gen},i}$, $s_{\text{gen},j}$, and $\text{tol}$ represent the number of structures in the generated set, the crystallographic distance according to \texttt{StructureMatcher} from \texttt{pymatgen.analysis.structure\_matcher}, the $i$-th structure of the generated set, the $j$-th structure of the generated set, and a tolerance threshold, respectively.

While useful individually, these metrics can also be combined for multi-criteria filtering. This matters because each metric, taken alone, can be ``hacked'' in some sense. For example, the novelty metric can be made perfect simply by generating a diverse set of nonsensical crystal structures. Likewise, the validity score can be inflated simply by passing in the training structures as the generated structures. To combat this, multiple criteria may be considered simultaneously: for example, requiring that a structure be both novel and unique while also passing filtering criteria such as non-overlapping atoms, stoichiometry rules, or \texttt{checkCIF} criteria \cite{spekCheckCIFValidationALERTS2020}. Additional filters based on machine learning prediction models can be applied to properties such as formation energy (i.e., must be negative), energy above hull, ICSD classification, and coordination number. Of particular interest is relaxing the structures with machine-learning-based universal interatomic potential models such as M3GNet \cite{chen_universal_2022} prior to filtering.
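The multi-criteria idea in the last paragraph can be sketched in a few lines of Python. Everything here is illustrative: `filter_candidates` is a hypothetical helper, floats stand in for structures, and `checks` is a list of arbitrary predicates standing in for filters such as non-overlapping atoms or a negative predicted formation energy.

```python
def filter_candidates(gen, train, d, tol, checks=()):
    """Keep generated items that are novel, not yet kept, and pass every check."""
    kept = []
    for s in gen:
        is_novel = all(d(s, t) > tol for t in train)
        is_unique = all(d(s, k) > tol for k in kept)  # first occurrence wins
        if is_novel and is_unique and all(check(s) for check in checks):
            kept.append(s)
    return kept


def abs_diff(a, b):
    # Toy distance: an assumption standing in for a crystallographic distance
    return abs(a - b)


train_set = [0.0]
gen_set = [0.0, 2.0, 2.01, -3.0, 5.0]
survivors = filter_candidates(
    gen_set, train_set, abs_diff, tol=0.1, checks=[lambda s: s >= 0]
)
print(survivors)
```

Here 0.0 is rejected as a copy of the training data, 2.01 as a repeat of 2.0, and -3.0 by the extra check, so gaming a single criterion no longer helps.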
13 changes: 13 additions & 0 deletions scripts/load_imagen_pytorch_generated.py
@@ -5,6 +5,7 @@ def main():
    from xtal2png import XtalConverter

    from matbench_genmetrics.core import MPTSMetrics10, MPTSMetrics1000
    from matbench_genmetrics.utils.plotting import plot_structures_2d

    fold = 0
    dummy = False
@@ -26,11 +27,23 @@ def main():

    gen_structures = xc.png2xtal(gen_images)

    # with open(path.join(data_dir, f"gen_structures_fold={fold}.pkl"), "wb") as f:
    #     pickle.dump(gen_structures, f)

    # with open(path.join(data_dir, f"gen_structures_fold={fold}.pkl"), "rb") as f:
    #     gen_structures_loaded = pickle.load(f)

    mptm.get_train_and_val_data(fold)
    mptm.evaluate_and_record(fold, gen_structures)

    print(mptm.recorded_metrics)

    # save() appends ".pkl" and ".json" itself, so pass only the path stem
    mptm.save(
        path.join(data_dir, f"gen_metrics_fold={fold},epoch={checkpoint_epoch}")
    )

    fig, _ = plot_structures_2d(gen_structures, 6, 5)

    return mptm

1 change: 1 addition & 0 deletions setup.cfg
@@ -59,6 +59,7 @@ install_requires =
    mp-time-split[pyxtal]
    pystow
    element-coder
    pymatviz


[options.packages.find]
13 changes: 13 additions & 0 deletions src/matbench_genmetrics/core.py
@@ -1,6 +1,8 @@
"""Core functionality for matbench-genmetrics (generative materials benchmarking)"""
import argparse
import json
import logging
import pickle
import sys
from pathlib import Path
from typing import List, Optional
@@ -518,6 +520,17 @@ def evaluate_and_record(self, fold, gen_structures, test_pred_structures=None):
        for metric, value in self.recorded_metrics[fold].items():
            setattr(self, metric, value)

    def save(self, fpath_stem):
        """Pickle this object to `<stem>.pkl` and dump its metrics to `<stem>.json`."""
        fpath_stem = str(fpath_stem)  # also accept pathlib.Path
        with open(fpath_stem + ".pkl", "wb") as f:
            pickle.dump(self, f)

        with open(fpath_stem + ".json", "w") as fp:
            json.dump(self.recorded_metrics, fp)

    @staticmethod
    def load(fpath):
        """Return the pickled object stored at `fpath` (only load trusted files)."""
        with open(fpath, "rb") as f:
            return pickle.load(f)


class MPTSMetrics10(MPTSMetrics):
    def __init__(self, dummy=False, verbose=True):
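One detail of the JSON side of `save` is worth illustrating: JSON object keys are always strings, so metrics keyed by an integer fold come back keyed by `"0"`, `"1"`, and so on. The sketch below assumes `recorded_metrics` is a dict keyed by integer fold with JSON-serializable values (the diff indexes it as `recorded_metrics[fold]`). `save_metrics` and `load_metrics` are hypothetical helpers, not part of the library, and the metric values are made up.

```python
import json
import os
import tempfile


def save_metrics(recorded_metrics, fpath_stem):
    """Write the metrics dict to `<stem>.json` and return the file path."""
    fpath = str(fpath_stem) + ".json"
    with open(fpath, "w") as fp:
        json.dump(recorded_metrics, fp)
    return fpath


def load_metrics(fpath):
    """Read the metrics back, restoring the integer fold keys."""
    with open(fpath) as fp:
        raw = json.load(fp)
    return {int(fold): metrics for fold, metrics in raw.items()}


metrics = {0: {"validity": 0.9, "coverage": 0.1}}  # made-up values
with tempfile.TemporaryDirectory() as tmp:
    saved = save_metrics(metrics, os.path.join(tmp, "gen_metrics_fold=0"))
    with open(saved) as fp:
        raw_keys = list(json.load(fp))  # keys as stored on disk
    restored = load_metrics(saved)
print(raw_keys, restored)
```

Anyone reading the `.json` file directly, rather than unpickling the `.pkl`, would need a conversion like the one in `load_metrics` to index by fold number again.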
52 changes: 52 additions & 0 deletions src/matbench_genmetrics/utils/plotting.py
@@ -0,0 +1,52 @@
# from ase.visualize import view
# from pymatgen.io.ase import AseAtomsAdaptor
import matplotlib.pyplot as plt
import numpy as np
from pymatviz.structure_viz import plot_structure_2d

# def plot_structure_3d(structure):
#     view(AseAtomsAdaptor.get_atoms(structure))


def plot_structures_2d(structures, nrows, ncols, seed=10, formula_as_title=True):
    if len(structures) > nrows * ncols:
        # Sample indices rather than the structures themselves: NumPy would
        # first try to coerce a list of Structure objects into an array
        ids = np.random.RandomState(seed=seed).choice(
            len(structures), size=nrows * ncols, replace=False
        )
        plot_structures = [structures[i] for i in ids]
    else:
        plot_structures = structures

    # squeeze=False keeps `axes` a 2D array even for a 1x1 grid
    fig, axes = plt.subplots(nrows, ncols, squeeze=False)

    for s, ax in zip(plot_structures, axes.flatten()):
        plot_structure_2d(s, ax=ax)
        if formula_as_title:
            formula = s.composition.reduced_formula
            if len(formula) > 15:
                formula = formula[0:7] + ".." + formula[-7:]
            ax.set_title(formula)

    return fig, axes


def plot_images(images, nrows, ncols, seed=10, titles=None):
    # Images carry no composition, so titles (if any) are passed explicitly
    # instead of being derived via `formula_as_title`
    ids = np.arange(len(images))
    if len(images) > nrows * ncols:
        # get random images (by index, so titles stay aligned)
        ids = np.random.RandomState(seed=seed).choice(
            len(images), size=nrows * ncols, replace=False
        )

    fig, axes = plt.subplots(nrows, ncols, squeeze=False)

    for i, ax in zip(ids, axes.flatten()):
        ax.imshow(images[i])
        if titles is not None:
            ax.set_title(titles[i])

    return fig, axes
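The two pieces of plain logic in the plotting helpers, seeded subsampling to fill the grid and title shortening, can be exercised without matplotlib. This sketch uses hypothetical helper names and swaps in the standard-library `random.Random` for NumPy's `RandomState` to stay dependency-free, so the sampled subset will differ from what the NumPy code selects for the same seed.

```python
import random


def sample_for_grid(items, nrows, ncols, seed=10):
    """Return at most nrows * ncols items, chosen reproducibly by index."""
    n_cells = nrows * ncols
    if len(items) <= n_cells:
        return list(items)
    ids = random.Random(seed).sample(range(len(items)), n_cells)
    return [items[i] for i in ids]


def shorten(formula, max_len=15):
    """Abbreviate long formulas as first 7 characters, '..', last 7 characters."""
    if len(formula) > max_len:
        return formula[:7] + ".." + formula[-7:]
    return formula


subset = sample_for_grid(list(range(100)), nrows=6, ncols=5)
print(len(subset), shorten("Fe2O3"), shorten("abcdefghijklmnopqrst"))
```

Note that a shortened title is 16 characters long, one more than the 15-character threshold that triggers shortening.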