Skip to content

Commit

Permalink
logic for symmetric matrices and calculating fingerprints
Browse files Browse the repository at this point in the history
for latter, if fingerprints are not already calculated
  • Loading branch information
sgbaird committed Aug 6, 2022
1 parent 84d42b5 commit 8a22721
Showing 1 changed file with 58 additions and 26 deletions.
84 changes: 58 additions & 26 deletions src/matbench_genmetrics/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import List, Optional

import numpy as np
import pystow
from mp_time_split.core import MPTimeSplit
from pymatgen.core.structure import Structure
from pystow import ensure_csv
Expand Down Expand Up @@ -56,10 +55,10 @@ def fib(n):

IN_COLAB = "google.colab" in sys.modules

FULL_COMP_SNAPSHOT_NAME = "comp_fingerprints.csv"
DUMMY_COMP_SNAPSHOT_NAME = "dummy_comp_fingerprints.csv"
FULL_STRUCT_SNAPSHOT_NAME = "struct_fingerprints.csv"
DUMMY_STRUCT_SNAPSHOT_NAME = "dummy_struct_fingerprints.csv"
FULL_COMP_NAME = "comp_fingerprints.csv"
DUMMY_COMP_NAME = "dummy_comp_fingerprints.csv"
FULL_STRUCT_NAME = "struct_fingerprints.csv"
DUMMY_STRUCT_NAME = "dummy_struct_fingerprints.csv"

FULL_COMP_CHECKSUM_FROZEN = "0d714081a8f0bc53af84b0ce96d3536f"
DUMMY_COMP_CHECKSUM_FROZEN = "5630a3bfc7cbeac0cb3d7897b02aae9f"
Expand All @@ -72,7 +71,7 @@ def fib(n):
DUMMY_STRUCT_URL = "https://figshare.com/ndownloader/files/36582177"


MBGM_HOME = pystow.join("matbench-genmetrics")
DATA_HOME = "matbench-genmetrics"


class GenMatcher(object):
Expand All @@ -98,21 +97,34 @@ def __init__(
self.match_type = match_type
self.match_kwargs = match_kwargs

if test_structures and (gen_structures is None):
if gen_structures is None:
self.gen_structures = test_structures
self.symmetric = True
elif test_structures and gen_structures:
else:
self.gen_structures = gen_structures
self.symmetric = False
elif test_comp_fingerprints and (gen_comp_fingerprints is None):
self.gen_comp_fingerprints = test_comp_fingerprints
self.gen_struct_fingerprints = test_struct_fingerprints
self.symmetric = True
elif test_comp_fingerprints and gen_comp_fingerprints:
assert gen_comp_fingerprints is not None
self.gen_comp_fingerprints = gen_comp_fingerprints
self.gen_struct_fingerprints = gen_struct_fingerprints
self.symmetric = False

# featurize test and/or gen structures if features not provided
if self.match_type == "cdvae_coverage":
if test_comp_fingerprints is None or test_struct_fingerprints is None:
(
self.test_comp_fingerprints,
self.test_struct_fingerprints,
) = featurize_comp_struct(self.test_structures)

if self.symmetric:
self.gen_comp_fingerprints, self.gen_struct_fingerprints = (
self.test_comp_fingerprints,
self.test_struct_fingerprints,
)
elif gen_comp_fingerprints is None or gen_struct_fingerprints is None:
(
self.gen_comp_fingerprints,
self.gen_struct_fingerprints,
) = featurize_comp_struct(self.gen_structures)
else:
self.gen_comp_fingerprints = gen_comp_fingerprints
self.gen_struct_fingerprints = gen_struct_fingerprints

self.num_test = len(self.test_structures)
self.num_gen = len(self.gen_structures)
Expand Down Expand Up @@ -345,14 +357,24 @@ def load_fingerprints(self, dummy=False):

comp_url = DUMMY_COMP_URL if dummy else FULL_COMP_URL
struct_url = DUMMY_STRUCT_URL if dummy else FULL_STRUCT_URL
comp_name = DUMMY_COMP_NAME if dummy else FULL_COMP_NAME
struct_name = DUMMY_STRUCT_NAME if dummy else FULL_STRUCT_NAME

read_csv_kwargs = dict(index_col="material_id", sep=",")
self.comp_fingerprints_df = ensure_csv(
DATA_HOME,
name=comp_name,
url=comp_url,
read_csv_kwargs=read_csv_kwargs,
)
self.struct_fingerprints_df = ensure_csv(
DATA_HOME,
name=struct_name,
url=struct_url,
read_csv_kwargs=read_csv_kwargs,
)

self.comp_fps_df = ensure_csv(MBGM_HOME, url=comp_url)
self.struct_fps_df = ensure_csv(MBGM_HOME, url=struct_url)

self.comp_fingerprints = self.comp_fps_df.drop("material_id", axis=1).values
self.struct_fingerprints = self.struct_fps_df.drop("material_id", axis=1).values

return self.comp_fingerprints, self.struct_fingerprints
return self.comp_fingerprints_df, self.struct_fingerprints_df

def get_train_and_val_data(self, fold, include_val=False):

Expand All @@ -369,11 +391,11 @@ def get_train_and_val_data(self, fold, include_val=False):
comp_fps, struct_fps = self.load_fingerprints()

self.train_comp_fingerprints, self.val_comp_fingerprints = [
comp_fps.iloc[tvs] for tvs in self.mpt.trainval_splits[fold]
comp_fps.iloc[tvs].values for tvs in self.mpt.trainval_splits[fold]
]

self.train_struct_fingerprints, self.val_struct_fingerprints = [
struct_fps.iloc[tvs] for tvs in self.mpt.trainval_splits[fold]
struct_fps.iloc[tvs].values for tvs in self.mpt.trainval_splits[fold]
]
elif self.match_type == "StructureMatcher":
self.train_comp_fingerprints = None
Expand Down Expand Up @@ -596,3 +618,13 @@ def run():
# IN_COLAB = True
# except ImportError:
# IN_COLAB = False

# elif test_comp_fingerprints and (gen_comp_fingerprints is None):
# self.gen_comp_fingerprints = test_comp_fingerprints
# self.gen_struct_fingerprints = test_struct_fingerprints
# self.symmetric = True
# elif test_comp_fingerprints and gen_comp_fingerprints:
# assert gen_comp_fingerprints is not None
# self.gen_comp_fingerprints = gen_comp_fingerprints
# self.gen_struct_fingerprints = gen_struct_fingerprints
# self.symmetric = False

0 comments on commit 8a22721

Please sign in to comment.