Skip to content

Commit

Permalink
Merge pull request #47 from sparks-baird/fingerprint
Browse files Browse the repository at this point in the history
implement fingerprint functionality

Definitely still open to suggestions. Merging for now to get a new version out, since it seems to be working. Once the metric settles down, we should probably hardcode some tests for the final version of the fingerprinting (#50).
  • Loading branch information
sgbaird authored Aug 6, 2022
2 parents 594bab6 + 8a22721 commit b398bf2
Show file tree
Hide file tree
Showing 6 changed files with 536 additions and 151 deletions.
207 changes: 153 additions & 54 deletions notebooks/1.0-matbench-genmetrics-basic.ipynb

Large diffs are not rendered by default.

24 changes: 24 additions & 0 deletions scripts/fingerprint_snapshot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from os import path
from pathlib import Path

from matbench_genmetrics.utils.featurize import featurize_comp_struct

if __name__ == "__main__":
    # Forwarded to the featurizer and also selects the output sub-directory.
    dummy = False

    # Column names handed to the featurizer as keyword arguments.
    column_names = dict(
        comp_name="composition",
        struct_name="structure",
        material_id_name="material_id",
    )
    comp_fingerprints, struct_fingerprints = featurize_comp_struct(
        dummy=dummy,
        keep_as_df=True,
        **column_names,
    )

    # Output goes to data/interim, or data/interim/dummy when ``dummy`` is set.
    dir_parts = ["data", "interim"]
    if dummy:
        dir_parts.append("dummy")
    data_dir = path.join(*dir_parts)
    Path(data_dir).mkdir(exist_ok=True, parents=True)

    # Write one CSV per fingerprint table.
    snapshots = (
        ("comp_fingerprints.csv", comp_fingerprints),
        ("struct_fingerprints.csv", struct_fingerprints),
    )
    for file_name, fingerprints in snapshots:
        fingerprints.to_csv(path.join(data_dir, file_name))
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ install_requires =
# torch
# torch-geometric
mp-time-split[pyxtal]
pystow


[options.packages.find]
Expand Down
171 changes: 161 additions & 10 deletions src/matbench_genmetrics/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,22 @@
import argparse
import logging
import sys
from pathlib import Path
from typing import List, Optional

import numpy as np
from mp_time_split.core import MPTimeSplit
from pymatgen.core.structure import Structure
from pystow import ensure_csv
from scipy.stats import wasserstein_distance

from matbench_genmetrics import __version__
from matbench_genmetrics.utils.match import ALLOWED_MATCH_TYPES, get_match_matrix
from matbench_genmetrics.utils.featurize import featurize_comp_struct
from matbench_genmetrics.utils.match import (
ALLOWED_MATCH_TYPES,
cdvae_cov_compstruct_match_matrix,
get_structure_match_matrix,
)

# causes pytest to fail (tests not found, DLL load error)
# from matbench_genmetrics.cdvae.metrics import RecEval, GenEval, OptEval
Expand Down Expand Up @@ -48,17 +55,41 @@ def fib(n):

IN_COLAB = "google.colab" in sys.modules

# Local file names for the precomputed fingerprint tables; passed as ``name=``
# to ``ensure_csv`` in ``load_fingerprints``. The DUMMY_* variants are selected
# when ``dummy`` is truthy.
FULL_COMP_NAME = "comp_fingerprints.csv"
DUMMY_COMP_NAME = "dummy_comp_fingerprints.csv"
FULL_STRUCT_NAME = "struct_fingerprints.csv"
DUMMY_STRUCT_NAME = "dummy_struct_fingerprints.csv"

# Frozen checksums for the four fingerprint files.
# NOTE(review): 32 hex digits, presumably MD5 -- confirm. These constants are
# not referenced in the code visible here, so downloads may not currently be
# verified against them.
FULL_COMP_CHECKSUM_FROZEN = "0d714081a8f0bc53af84b0ce96d3536f"
DUMMY_COMP_CHECKSUM_FROZEN = "5630a3bfc7cbeac0cb3d7897b02aae9f"
FULL_STRUCT_CHECKSUM_FROZEN = "312a4a282c57d80aed19a07dd2760ad9"
DUMMY_STRUCT_CHECKSUM_FROZEN = "d402abc2ba383e6b18b24413bdd96a7e"

# figshare download URLs for the fingerprint files; passed as ``url=`` to
# ``ensure_csv`` in ``load_fingerprints``.
FULL_COMP_URL = "https://figshare.com/ndownloader/files/36581838"
DUMMY_COMP_URL = "https://figshare.com/ndownloader/files/36582174"
FULL_STRUCT_URL = "https://figshare.com/ndownloader/files/36581841"
DUMMY_STRUCT_URL = "https://figshare.com/ndownloader/files/36582177"


# Key given to ``ensure_csv`` as its first positional argument.
DATA_HOME = "matbench-genmetrics"


class GenMatcher(object):
def __init__(
self,
test_structures,
gen_structures: Optional[List[Structure]] = None,
test_comp_fingerprints: Optional[np.ndarray] = None,
gen_comp_fingerprints: Optional[np.ndarray] = None,
test_struct_fingerprints: Optional[np.ndarray] = None,
gen_struct_fingerprints: Optional[np.ndarray] = None,
verbose=True,
match_type="cdvae_coverage",
**match_kwargs,
) -> None:
self.test_structures = test_structures
self.test_comp_fingerprints = test_comp_fingerprints
self.test_struct_fingerprints = test_struct_fingerprints
self.verbose = verbose
assert (
match_type in ALLOWED_MATCH_TYPES
Expand All @@ -73,6 +104,28 @@ def __init__(
self.gen_structures = gen_structures
self.symmetric = False

# featurize test and/or gen structures if features not provided
if self.match_type == "cdvae_coverage":
if test_comp_fingerprints is None or test_struct_fingerprints is None:
(
self.test_comp_fingerprints,
self.test_struct_fingerprints,
) = featurize_comp_struct(self.test_structures)

if self.symmetric:
self.gen_comp_fingerprints, self.gen_struct_fingerprints = (
self.test_comp_fingerprints,
self.test_struct_fingerprints,
)
elif gen_comp_fingerprints is None or gen_struct_fingerprints is None:
(
self.gen_comp_fingerprints,
self.gen_struct_fingerprints,
) = featurize_comp_struct(self.gen_structures)
else:
self.gen_comp_fingerprints = gen_comp_fingerprints
self.gen_struct_fingerprints = gen_struct_fingerprints

self.num_test = len(self.test_structures)
self.num_gen = len(self.gen_structures)

Expand All @@ -83,14 +136,25 @@ def match_matrix(self):
if self._match_matrix is not None:
return self._match_matrix

match_matrix = get_match_matrix(
self.test_structures,
self.gen_structures,
match_type=self.match_type,
symmetric=self.symmetric,
verbose=self.verbose,
**self.match_kwargs,
)
if self.match_type == "StructureMatcher":
match_matrix = get_structure_match_matrix(
self.test_structures,
self.gen_structures,
match_type=self.match_type,
symmetric=self.symmetric,
verbose=self.verbose,
**self.match_kwargs,
)
elif self.match_type == "cdvae_coverage":
match_matrix = cdvae_cov_compstruct_match_matrix(
self.test_comp_fingerprints,
self.gen_comp_fingerprints,
self.test_struct_fingerprints,
self.gen_struct_fingerprints,
symmetric=self.symmetric,
verbose=self.verbose,
**self.match_kwargs,
)

self._match_matrix = match_matrix

Expand Down Expand Up @@ -140,6 +204,10 @@ def __init__(
train_structures,
test_structures,
gen_structures,
train_comp_fingerprints=None,
test_comp_fingerprints=None,
train_struct_fingerprints=None,
test_struct_fingerprints=None,
test_pred_structures=None,
verbose=True,
match_type="cdvae_coverage",
Expand All @@ -148,10 +216,20 @@ def __init__(
self.train_structures = train_structures
self.test_structures = test_structures
self.gen_structures = gen_structures
self.train_comp_fingerprints = train_comp_fingerprints
self.test_comp_fingerprints = test_comp_fingerprints
self.train_struct_fingerprints = train_struct_fingerprints
self.test_struct_fingerprints = test_struct_fingerprints
self.test_pred_structures = test_pred_structures
self.verbose = verbose
self.match_type = match_type
self.match_kwargs = match_kwargs

(
self.gen_comp_fingerprints,
self.gen_struct_fingerprints,
) = featurize_comp_struct(self.gen_structures)

self._cdvae_metrics = None
self._mpts_metrics = None

Expand Down Expand Up @@ -181,6 +259,7 @@ def __init__(
@property
def validity(self):
"""Scaled Wasserstein distance between real (train/test) and gen structures."""
# TODO: implement notion of compositional validity, since this is only structure
train_test_structures = self.train_structures + self.test_structures
train_test_spg = [ts.get_space_group_info()[1] for ts in train_test_structures]
gen_spg = [ts.get_space_group_info()[1] for ts in self.gen_structures]
Expand All @@ -193,6 +272,10 @@ def coverage(self):
self.coverage_matcher = GenMatcher(
self.test_structures,
self.gen_structures,
test_comp_fingerprints=self.test_comp_fingerprints,
test_struct_fingerprints=self.test_struct_fingerprints,
gen_comp_fingerprints=self.gen_comp_fingerprints,
gen_struct_fingerprints=self.gen_struct_fingerprints,
verbose=self.verbose,
match_type=self.match_type,
**self.match_kwargs,
Expand All @@ -205,6 +288,10 @@ def novelty(self):
self.similarity_matcher = GenMatcher(
self.train_structures,
self.gen_structures,
test_comp_fingerprints=self.train_comp_fingerprints,
test_struct_fingerprints=self.train_struct_fingerprints,
gen_comp_fingerprints=self.gen_comp_fingerprints,
gen_struct_fingerprints=self.gen_struct_fingerprints,
verbose=self.verbose,
match_type=self.match_type,
**self.match_kwargs,
Expand All @@ -220,6 +307,10 @@ def uniqueness(self):
self.commonality_matcher = GenMatcher(
self.gen_structures,
self.gen_structures,
test_comp_fingerprints=self.gen_comp_fingerprints,
test_struct_fingerprints=self.gen_struct_fingerprints,
gen_comp_fingerprints=self.gen_comp_fingerprints,
gen_struct_fingerprints=self.gen_struct_fingerprints,
verbose=self.verbose,
match_type=self.match_type,
**self.match_kwargs,
Expand All @@ -244,28 +335,74 @@ def __init__(
dummy=False,
verbose=True,
num_gen=None,
save_dir="results",
match_type="cdvae_coverage",
**match_kwargs,
):
self.dummy = dummy
self.verbose = verbose
self.num_gen = num_gen
self.save_dir = save_dir
self.match_type = match_type
self.match_kwargs = match_kwargs

Path(self.save_dir).mkdir(exist_ok=True, parents=True)

self.mpt = MPTimeSplit(target="energy_above_hull")
self.folds = self.mpt.folds
self.gms: List[Optional[GenMetrics]] = [None] * len(self.folds)
self.recorded_metrics = {}

def load_fingerprints(self, dummy=None):
    """Load the precomputed composition and structure fingerprint tables.

    Both tables are obtained through ``pystow.ensure_csv`` and are indexed
    by their ``material_id`` column. The results are also stored on the
    instance as ``comp_fingerprints_df`` and ``struct_fingerprints_df``.

    Parameters
    ----------
    dummy : bool, optional
        Whether to load the dummy fingerprint files instead of the full
        ones. When omitted (``None``), the instance's own ``self.dummy``
        setting is used so that the fingerprints agree with the dataset
        loaded via ``self.mpt.load(dummy=self.dummy)``. Previously the
        default was ``False``, which made argument-less calls fetch the
        full tables even for a dummy run.

    Returns
    -------
    tuple
        ``(comp_fingerprints_df, struct_fingerprints_df)``.
    """
    if dummy is None:
        # Follow the instance configuration; rows are later selected by
        # position using splits of the dataset loaded with ``self.dummy``.
        dummy = self.dummy

    if dummy:
        comp_url, comp_name = DUMMY_COMP_URL, DUMMY_COMP_NAME
        struct_url, struct_name = DUMMY_STRUCT_URL, DUMMY_STRUCT_NAME
    else:
        comp_url, comp_name = FULL_COMP_URL, FULL_COMP_NAME
        struct_url, struct_name = FULL_STRUCT_URL, FULL_STRUCT_NAME

    # The separator is given explicitly because the files are comma-separated;
    # NOTE(review): pystow is believed to default to tab-separated -- confirm.
    read_csv_kwargs = {"index_col": "material_id", "sep": ","}

    self.comp_fingerprints_df = ensure_csv(
        DATA_HOME,
        name=comp_name,
        url=comp_url,
        read_csv_kwargs=read_csv_kwargs,
    )
    self.struct_fingerprints_df = ensure_csv(
        DATA_HOME,
        name=struct_name,
        url=struct_url,
        read_csv_kwargs=read_csv_kwargs,
    )

    return self.comp_fingerprints_df, self.struct_fingerprints_df

def get_train_and_val_data(self, fold, include_val=False):
self.mpt.load(dummy=self.dummy)

if self.recorded_metrics == {}:
self.mpt.load(dummy=self.dummy)
(
self.train_inputs,
self.val_inputs,
self.train_outputs,
self.val_outputs,
) = self.mpt.get_train_and_val_data(fold)

if self.match_type == "cdvae_coverage":
comp_fps, struct_fps = self.load_fingerprints()

self.train_comp_fingerprints, self.val_comp_fingerprints = [
comp_fps.iloc[tvs].values for tvs in self.mpt.trainval_splits[fold]
]

self.train_struct_fingerprints, self.val_struct_fingerprints = [
struct_fps.iloc[tvs].values for tvs in self.mpt.trainval_splits[fold]
]
elif self.match_type == "StructureMatcher":
self.train_comp_fingerprints = None
self.val_comp_fingerprints = None
self.train_struct_fingerprints = None
self.val_struct_fingerprints = None

if include_val:
return self.train_inputs, self.val_inputs

Expand All @@ -280,6 +417,10 @@ def evaluate_and_record(self, fold, gen_structures, test_pred_structures=None):
self.train_inputs.tolist(),
self.val_inputs.tolist(),
gen_structures,
train_comp_fingerprints=self.train_comp_fingerprints,
test_comp_fingerprints=self.val_comp_fingerprints,
train_struct_fingerprints=self.train_struct_fingerprints,
test_struct_fingerprints=self.val_struct_fingerprints,
test_pred_structures=test_pred_structures,
verbose=self.verbose,
match_type=self.match_type,
Expand Down Expand Up @@ -477,3 +618,13 @@ def run():
# IN_COLAB = True
# except ImportError:
# IN_COLAB = False

# elif test_comp_fingerprints and (gen_comp_fingerprints is None):
# self.gen_comp_fingerprints = test_comp_fingerprints
# self.gen_struct_fingerprints = test_struct_fingerprints
# self.symmetric = True
# elif test_comp_fingerprints and gen_comp_fingerprints:
# assert gen_comp_fingerprints is not None
# self.gen_comp_fingerprints = gen_comp_fingerprints
# self.gen_struct_fingerprints = gen_struct_fingerprints
# self.symmetric = False
Loading

0 comments on commit b398bf2

Please sign in to comment.