diff --git a/README.md b/README.md
index a461dbc9..d5a47be9 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,3 @@
-[![name](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/Training_ALIGNN_model_example.ipynb)
![alt text](https://github.com/usnistgov/alignn/actions/workflows/main.yml/badge.svg)
[![codecov](https://codecov.io/gh/usnistgov/alignn/branch/main/graph/badge.svg?token=S5X4OYC80V)](https://codecov.io/gh/usnistgov/alignn)
[![PyPI version](https://badge.fury.io/py/alignn.svg)](https://badge.fury.io/py/alignn)
@@ -19,7 +18,6 @@
* [Installation](#install)
* [Examples](#example)
* [Pre-trained models](#pretrained)
-* [Quick start using colab](#colab)
* [JARVIS-ALIGNN webapp](#webapp)
* [ALIGNN-FF & ASE Calculator](#alignnff)
* [Performances on a few datasets](#performances)
@@ -111,6 +109,18 @@ pip install dgl==1.0.1+cu117 -f https://data.dgl.ai/wheels/cu117/repo.html
Examples
---------
+
+| Notebooks | Google Colab | Descriptions |
+| ---------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| [Regression model](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/alignn_jarvis_leaderboard.ipynb) | [![Open in Google Colab]](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/alignn_jarvis_leaderboard.ipynb) | Example of developing a single-output regression model for exfoliation energies of 2D materials. |
+| [MLFF](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/Train_ALIGNNFF_Mlearn.ipynb) | [![Open in Google Colab]](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/Train_ALIGNNFF_Mlearn.ipynb) | Example of training a machine-learning force field for silicon. |
+| [Miscellaneous tasks](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/Training_ALIGNN_model_example.ipynb) | [![Open in Google Colab]](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/Training_ALIGNN_model_example.ipynb) | Examples of developing single-output (e.g., formation energy, bandgap) or multi-output (e.g., phonon DOS, electron DOS) regression and classification (e.g., metal vs. non-metal) models, and of using several pretrained models. |
+
+
+[Open in Google Colab]: https://colab.research.google.com/assets/colab-badge.svg
+
Here, we provide examples for property-prediction tasks, development of machine-learning force fields (MLFF), and usage of pre-trained property predictors, MLFFs, webapps, etc.
#### Dataset preparation for property prediction tasks
@@ -174,18 +184,7 @@ An example of prediction formation energy per atom using JARVIS-DFT dataset trai
```
pretrained.py --model_name jv_formation_energy_peratom_alignn --file_format poscar --file_path alignn/examples/sample_data/POSCAR-JVASP-10.vasp
```
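+
+The same pretrained models can also be called from Python. A minimal sketch (it mirrors the `get_multiple_predictions` call exercised in `alignn/tests/test_prop.py`; the silicon lattice values are illustrative):
+
+```python
+from jarvis.core.atoms import Atoms
+from alignn.pretrained import get_multiple_predictions
+
+# Bulk silicon described with the jarvis-tools Atoms class.
+Si = Atoms(
+    lattice_mat=[[0, 2.73, 2.73], [2.73, 0, 2.73], [2.73, 2.73, 0]],
+    coords=[[0, 0, 0], [0.25, 0.25, 0.25]],
+    elements=["Si", "Si"],
+)
+
+# Runs the default pretrained model over the list of structures and
+# writes the predictions to a JSON file.
+get_multiple_predictions(atoms_array=[Si, Si])
+```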
-
-Quick start using GoogleColab notebook example
------------------------------------------------
-
-The following [notebook](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/Training_ALIGNN_model_example.ipynb) provides an example of 1) installing ALIGNN model, 2) training the example data and 3) using the pretrained models. For this example, you don't need to install alignn package on your local computer/cluster, it requires a gmail account to login. Learn more about Google colab [here](https://colab.research.google.com/notebooks/intro.ipynb).
-
-[![name](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/Training_ALIGNN_model_example.ipynb)
-
-
-The following [notebook](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/Train_ALIGNNFF_Mlearn.ipynb) provides an example of ALIGNN-FF model.
-For additional notebooks, checkout [JARVIS-Tools-Notebooks](https://github.com/JARVIS-Materials-Design/jarvis-tools-notebooks?tab=readme-ov-file#artificial-intelligencemachine-learning)
Web-app
diff --git a/alignn/__init__.py b/alignn/__init__.py
index 83674b0f..86b75d31 100644
--- a/alignn/__init__.py
+++ b/alignn/__init__.py
@@ -1,3 +1,3 @@
"""Version number."""
-__version__ = "2024.4.10"
+__version__ = "2024.4.20"
diff --git a/alignn/config.py b/alignn/config.py
index c8a21279..d7c807ae 100644
--- a/alignn/config.py
+++ b/alignn/config.py
@@ -207,7 +207,8 @@ class TrainingConfig(BaseSettings):
distributed: bool = False
data_parallel: bool = False
n_early_stopping: Optional[int] = None # typically 50
- output_dir: str = os.path.abspath(".") # typically 50
+ output_dir: str = os.path.abspath(".")
+ use_lmdb: bool = True
# alignn_layers: int = 4
# gcn_layers: int =4
# edge_input_features: int= 80
diff --git a/alignn/data.py b/alignn/data.py
index b5872501..c52d0c74 100644
--- a/alignn/data.py
+++ b/alignn/data.py
@@ -1,33 +1,19 @@
-"""Jarvis-dgl data loaders and DGLGraph utilities."""
+"""ALIGNN data loaders and DGLGraph utilities."""
import random
-from pathlib import Path
from typing import Optional
-
-# from typing import Dict, List, Optional, Set, Tuple
-
+from torch.utils.data.distributed import DistributedSampler
import os
import torch
-import dgl
import numpy as np
-import pandas as pd
-from jarvis.core.atoms import Atoms
-from alignn.graphs import Graph, StructureDataset
-
-# from jarvis.core.graphs import Graph, StructureDataset
from jarvis.db.figshare import data as jdata
-from torch.utils.data import DataLoader
from tqdm import tqdm
import math
from jarvis.db.jsonutils import dumpjson
-
-# from sklearn.pipeline import Pipeline
+from dgl.dataloading import GraphDataLoader
import pickle as pk
-
-# from sklearn.decomposition import PCA # ,KernelPCA
from sklearn.preprocessing import StandardScaler
-# use pandas progress_apply
tqdm.pandas()
@@ -68,96 +54,6 @@ def mean_absolute_deviation(data, axis=None):
return np.mean(np.absolute(data - np.mean(data, axis)), axis)
-def load_graphs(
- dataset=[],
- name: str = "dft_3d",
- neighbor_strategy: str = "k-nearest",
- cutoff: float = 8,
- cutoff_extra: float = 3,
- max_neighbors: int = 12,
- cachedir: Optional[Path] = None,
- use_canonize: bool = False,
- id_tag="jid",
- # extra_feats_json=None,
-):
- """Construct crystal graphs.
-
- Load only atomic number node features
- and bond displacement vector edge features.
-
- Resulting graphs have scheme e.g.
- ```
- Graph(num_nodes=12, num_edges=156,
- ndata_schemes={'atom_features': Scheme(shape=(1,)}
- edata_schemes={'r': Scheme(shape=(3,)})
- ```
- """
-
- def atoms_to_graph(atoms):
- """Convert structure dict to DGLGraph."""
- structure = (
- Atoms.from_dict(atoms) if isinstance(atoms, dict) else atoms
- )
- return Graph.atom_dgl_multigraph(
- structure,
- cutoff=cutoff,
- cutoff_extra=cutoff_extra,
- atom_features="atomic_number",
- max_neighbors=max_neighbors,
- compute_line_graph=False,
- use_canonize=use_canonize,
- neighbor_strategy=neighbor_strategy,
- )
-
- if cachedir is not None:
- cachefile = cachedir / f"{name}-{neighbor_strategy}.bin"
- else:
- cachefile = None
-
- if cachefile is not None and cachefile.is_file():
- graphs, labels = dgl.load_graphs(str(cachefile))
- else:
- # print('dataset',dataset,type(dataset))
- print("Converting to graphs!")
- graphs = []
- # columns=dataset.columns
- for ii, i in tqdm(dataset.iterrows()):
- # print('iooooo',i)
- atoms = i["atoms"]
- structure = (
- Atoms.from_dict(atoms) if isinstance(atoms, dict) else atoms
- )
- g = Graph.atom_dgl_multigraph(
- structure,
- cutoff=cutoff,
- cutoff_extra=cutoff_extra,
- atom_features="atomic_number",
- max_neighbors=max_neighbors,
- compute_line_graph=False,
- use_canonize=use_canonize,
- neighbor_strategy=neighbor_strategy,
- id=i[id_tag],
- )
- # print ('ii',ii)
- if "extra_features" in i:
- natoms = len(atoms["elements"])
- # if "extra_features" in columns:
- g.ndata["extra_features"] = torch.tensor(
- [i["extra_features"] for n in range(natoms)]
- ).type(torch.get_default_dtype())
- graphs.append(g)
-
- # df = pd.DataFrame(dataset)
- # print ('df',df)
-
- # graphs = df["atoms"].progress_apply(atoms_to_graph).values
- # print ('graphs',graphs,graphs[0])
- if cachefile is not None:
- dgl.save_graphs(str(cachefile), graphs.tolist())
-
- return graphs
-
-
def get_id_train_val_test(
total_size=1000,
split_seed=123,
@@ -219,63 +115,6 @@ def get_id_train_val_test(
return id_train, id_val, id_test
-def get_torch_dataset(
- dataset=[],
- id_tag="jid",
- target="",
- target_atomwise="",
- target_grad="",
- target_stress="",
- neighbor_strategy="",
- atom_features="",
- use_canonize="",
- name="",
- line_graph="",
- cutoff=8.0,
- cutoff_extra=3.0,
- max_neighbors=12,
- classification=False,
- output_dir=".",
- tmp_name="dataset",
-):
- """Get Torch Dataset."""
- df = pd.DataFrame(dataset)
- # df['natoms']=df['atoms'].apply(lambda x: len(x['elements']))
- # print(" data df", df)
- vals = np.array([ii[target] for ii in dataset]) # df[target].values
- print("data range", np.max(vals), np.min(vals))
- f = open(os.path.join(output_dir, tmp_name + "_data_range"), "w")
- line = "Max=" + str(np.max(vals)) + "\n"
- f.write(line)
- line = "Min=" + str(np.min(vals)) + "\n"
- f.write(line)
- f.close()
-
- graphs = load_graphs(
- df,
- name=name,
- neighbor_strategy=neighbor_strategy,
- use_canonize=use_canonize,
- cutoff=cutoff,
- cutoff_extra=cutoff_extra,
- max_neighbors=max_neighbors,
- id_tag=id_tag,
- )
- data = StructureDataset(
- df,
- graphs,
- target=target,
- target_atomwise=target_atomwise,
- target_grad=target_grad,
- target_stress=target_stress,
- atom_features=atom_features,
- line_graph=line_graph,
- id_tag=id_tag,
- classification=classification,
- )
- return data
-
-
def get_train_val_loaders(
dataset: str = "dft_3d",
dataset_array=None,
@@ -298,9 +137,10 @@ def get_train_val_loaders(
workers: int = 0,
pin_memory: bool = True,
save_dataloader: bool = False,
- filename: str = "sample",
+ filename: str = "./",
id_tag: str = "jid",
- use_canonize: bool = False,
+ use_canonize: bool = True,
+ # use_ddp: bool = False,
cutoff: float = 8.0,
cutoff_extra: float = 3.0,
max_neighbors: int = 12,
@@ -310,11 +150,24 @@ def get_train_val_loaders(
keep_data_order=False,
output_features=1,
output_dir=None,
+ world_size=0,
+ rank=0,
+ use_lmdb: bool = True,
):
"""Help function to set up JARVIS train and val dataloaders."""
+    # Choose the dataset backend: LMDB keeps graphs on disk, while the
+    # plain dataset holds everything in memory.
+    if use_lmdb:
+        print("Using LMDB dataset.")
+        from alignn.lmdb_dataset import get_torch_dataset
+    else:
+        print("Not using LMDB dataset; memory footprint may be high.")
+        from alignn.dataset import get_torch_dataset
train_sample = filename + "_train.data"
val_sample = filename + "_val.data"
test_sample = filename + "_test.data"
+    if os.path.exists(train_sample):
+        # Existing dataset files are reused; tell the user how to clear
+        # them when training from scratch.
+        print("Found existing dataset files. If training from scratch, run:")
+        cmd = "rm -r " + train_sample + " " + val_sample + " " + test_sample
+        print(cmd)
# print ('output_dir data',output_dir)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
@@ -498,6 +351,19 @@ def get_train_val_loaders(
print("Data error", exp)
pass
+ if world_size > 1:
+ use_ddp = True
+ train_sampler = DistributedSampler(
+ dataset_train, num_replicas=world_size, rank=rank
+ )
+ val_sampler = DistributedSampler(
+ dataset_val, num_replicas=world_size, rank=rank
+ )
+ else:
+ use_ddp = False
+ train_sampler = None
+ val_sampler = None
+ tmp_name = filename + "train_data"
train_data = get_torch_dataset(
dataset=dataset_train,
id_tag=id_tag,
@@ -515,8 +381,11 @@ def get_train_val_loaders(
max_neighbors=max_neighbors,
classification=classification_threshold is not None,
output_dir=output_dir,
- tmp_name="train_data",
+ sampler=train_sampler,
+ tmp_name=tmp_name,
+ # tmp_name="train_data",
)
+ tmp_name = filename + "val_data"
val_data = (
get_torch_dataset(
dataset=dataset_val,
@@ -532,14 +401,17 @@ def get_train_val_loaders(
line_graph=line_graph,
cutoff=cutoff,
cutoff_extra=cutoff_extra,
+ sampler=val_sampler,
max_neighbors=max_neighbors,
classification=classification_threshold is not None,
output_dir=output_dir,
- tmp_name="val_data",
+ tmp_name=tmp_name,
+ # tmp_name="val_data",
)
if len(dataset_val) > 0
else None
)
+ tmp_name = filename + "test_data"
test_data = (
get_torch_dataset(
dataset=dataset_test,
@@ -558,7 +430,8 @@ def get_train_val_loaders(
max_neighbors=max_neighbors,
classification=classification_threshold is not None,
output_dir=output_dir,
- tmp_name="test_data",
+ tmp_name=tmp_name,
+ # tmp_name="test_data",
)
if len(dataset_test) > 0
else None
@@ -570,7 +443,8 @@ def get_train_val_loaders(
collate_fn = train_data.collate_line_graph
# use a regular pytorch dataloader
- train_loader = DataLoader(
+ train_loader = GraphDataLoader(
+ # train_loader = DataLoader(
train_data,
batch_size=batch_size,
shuffle=True,
@@ -578,9 +452,11 @@ def get_train_val_loaders(
drop_last=True,
num_workers=workers,
pin_memory=pin_memory,
+ use_ddp=use_ddp,
)
- val_loader = DataLoader(
+ val_loader = GraphDataLoader(
+ # val_loader = DataLoader(
val_data,
batch_size=batch_size,
shuffle=False,
@@ -588,10 +464,12 @@ def get_train_val_loaders(
drop_last=True,
num_workers=workers,
pin_memory=pin_memory,
+ use_ddp=use_ddp,
)
test_loader = (
- DataLoader(
+ GraphDataLoader(
+ # DataLoader(
test_data,
batch_size=1,
shuffle=False,
@@ -599,6 +477,7 @@ def get_train_val_loaders(
drop_last=False,
num_workers=workers,
pin_memory=pin_memory,
+ use_ddp=use_ddp,
)
if len(dataset_test) > 0
else None
diff --git a/alignn/dataset.py b/alignn/dataset.py
new file mode 100644
index 00000000..6baec251
--- /dev/null
+++ b/alignn/dataset.py
@@ -0,0 +1,164 @@
+"""Module to prepare ALIGNN dataset."""
+
+from pathlib import Path
+from typing import Optional
+import os
+import torch
+import dgl
+import numpy as np
+import pandas as pd
+from jarvis.core.atoms import Atoms
+from alignn.graphs import Graph, StructureDataset
+from tqdm import tqdm
+
+tqdm.pandas()
+
+
+def load_graphs(
+ dataset=[],
+ name: str = "dft_3d",
+ neighbor_strategy: str = "k-nearest",
+ cutoff: float = 8,
+ cutoff_extra: float = 3,
+ max_neighbors: int = 12,
+ cachedir: Optional[Path] = None,
+ use_canonize: bool = False,
+ id_tag="jid",
+ # extra_feats_json=None,
+    map_size=1e12,  # unused here; mirrors the LMDB loader's signature
+):
+ """Construct crystal graphs.
+
+ Load only atomic number node features
+ and bond displacement vector edge features.
+
+ Resulting graphs have scheme e.g.
+ ```
+ Graph(num_nodes=12, num_edges=156,
+ ndata_schemes={'atom_features': Scheme(shape=(1,)}
+ edata_schemes={'r': Scheme(shape=(3,)})
+ ```
+ """
+
+ def atoms_to_graph(atoms):
+ """Convert structure dict to DGLGraph."""
+ structure = (
+ Atoms.from_dict(atoms) if isinstance(atoms, dict) else atoms
+ )
+ return Graph.atom_dgl_multigraph(
+ structure,
+ cutoff=cutoff,
+ cutoff_extra=cutoff_extra,
+ atom_features="atomic_number",
+ max_neighbors=max_neighbors,
+ compute_line_graph=False,
+ use_canonize=use_canonize,
+ neighbor_strategy=neighbor_strategy,
+ )
+
+ if cachedir is not None:
+ cachefile = cachedir / f"{name}-{neighbor_strategy}.bin"
+ else:
+ cachefile = None
+
+ if cachefile is not None and cachefile.is_file():
+ graphs, labels = dgl.load_graphs(str(cachefile))
+ else:
+ # print('dataset',dataset,type(dataset))
+ print("Converting to graphs!")
+ graphs = []
+ # columns=dataset.columns
+ for ii, i in tqdm(dataset.iterrows(), total=len(dataset)):
+ # print('iooooo',i)
+ atoms = i["atoms"]
+ structure = (
+ Atoms.from_dict(atoms) if isinstance(atoms, dict) else atoms
+ )
+ g = Graph.atom_dgl_multigraph(
+ structure,
+ cutoff=cutoff,
+ cutoff_extra=cutoff_extra,
+ atom_features="atomic_number",
+ max_neighbors=max_neighbors,
+ compute_line_graph=False,
+ use_canonize=use_canonize,
+ neighbor_strategy=neighbor_strategy,
+ id=i[id_tag],
+ )
+ # print ('ii',ii)
+ if "extra_features" in i:
+ natoms = len(atoms["elements"])
+ # if "extra_features" in columns:
+ g.ndata["extra_features"] = torch.tensor(
+ [i["extra_features"] for n in range(natoms)]
+ ).type(torch.get_default_dtype())
+ graphs.append(g)
+
+ # df = pd.DataFrame(dataset)
+ # print ('df',df)
+
+ # graphs = df["atoms"].progress_apply(atoms_to_graph).values
+ # print ('graphs',graphs,graphs[0])
+ if cachefile is not None:
+ dgl.save_graphs(str(cachefile), graphs.tolist())
+
+ return graphs
+
+
+def get_torch_dataset(
+ dataset=[],
+ id_tag="jid",
+ target="",
+ target_atomwise="",
+ target_grad="",
+ target_stress="",
+ neighbor_strategy="",
+ atom_features="",
+ use_canonize="",
+ name="",
+ line_graph="",
+ cutoff=8.0,
+ cutoff_extra=3.0,
+ max_neighbors=12,
+ classification=False,
+ output_dir=".",
+ tmp_name="dataset",
+ sampler=None,
+):
+ """Get Torch Dataset."""
+ df = pd.DataFrame(dataset)
+ # df['natoms']=df['atoms'].apply(lambda x: len(x['elements']))
+ # print(" data df", df)
+ vals = np.array([ii[target] for ii in dataset]) # df[target].values
+ print("data range", np.max(vals), np.min(vals))
+ f = open(os.path.join(output_dir, tmp_name + "_data_range"), "w")
+ line = "Max=" + str(np.max(vals)) + "\n"
+ f.write(line)
+ line = "Min=" + str(np.min(vals)) + "\n"
+ f.write(line)
+ f.close()
+
+ graphs = load_graphs(
+ df,
+ name=name,
+ neighbor_strategy=neighbor_strategy,
+ use_canonize=use_canonize,
+ cutoff=cutoff,
+ cutoff_extra=cutoff_extra,
+ max_neighbors=max_neighbors,
+ id_tag=id_tag,
+ )
+ data = StructureDataset(
+ df,
+ graphs,
+ target=target,
+ target_atomwise=target_atomwise,
+ target_grad=target_grad,
+ target_stress=target_stress,
+ atom_features=atom_features,
+ line_graph=line_graph,
+ id_tag=id_tag,
+ classification=classification,
+ sampler=sampler,
+ )
+ return data
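+
+
+if __name__ == "__main__":
+    # Minimal smoke test, mirroring the LMDB module (a sketch: downloads the
+    # JARVIS dft_2d dataset from figshare, so it needs network access; the
+    # keyword values below are illustrative choices, not enforced defaults).
+    from jarvis.db.figshare import data
+
+    dataset = data("dft_2d")
+    torch_dataset = get_torch_dataset(
+        dataset=dataset,
+        target="formation_energy_peratom",
+        neighbor_strategy="k-nearest",
+        atom_features="cgcnn",
+        line_graph=True,
+    )
+    print(torch_dataset)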
diff --git a/alignn/examples/sample_data_ff/config_example_atomwise.json b/alignn/examples/sample_data_ff/config_example_atomwise.json
index 7c761a4d..2b915ddb 100644
--- a/alignn/examples/sample_data_ff/config_example_atomwise.json
+++ b/alignn/examples/sample_data_ff/config_example_atomwise.json
@@ -31,12 +31,13 @@
"progress": true,
"log_tensorboard": false,
"standard_scalar_and_pca": false,
- "use_canonize": false,
+ "use_canonize": true,
"num_workers": 0,
"cutoff": 8.0,
"max_neighbors": 12,
"keep_data_order": true,
"distributed":false,
+ "use_lmdb": true,
"model": {
"name": "alignn_atomwise",
"atom_input_features": 92,
@@ -48,7 +49,10 @@
"graphwise_weight":0.85,
"gradwise_weight":0.05,
"atomwise_weight":0.0,
- "stresswise_weight":0.05
-
+ "stresswise_weight":0.05,
+ "add_reverse_forces":true,
+ "lg_on_fly":true
}
}
diff --git a/alignn/graphs.py b/alignn/graphs.py
index 74a10b7f..53e772ab 100644
--- a/alignn/graphs.py
+++ b/alignn/graphs.py
@@ -13,6 +13,7 @@
# from jarvis.core.atoms import Atoms
from collections import defaultdict
from typing import List, Tuple, Sequence, Optional
+from dgl.data import DGLDataset
import torch
import dgl
@@ -711,7 +712,7 @@ def compute_bond_cosines(edges):
return {"h": bond_cosine}
-class StructureDataset(torch.utils.data.Dataset):
+class StructureDataset(DGLDataset):
"""Dataset of crystal DGLGraphs."""
def __init__(
@@ -727,6 +728,7 @@ def __init__(
line_graph=False,
classification=False,
id_tag="jid",
+ sampler=None,
):
"""Pytorch Dataset for atomistic graphs.
diff --git a/alignn/lmdb_dataset.py b/alignn/lmdb_dataset.py
new file mode 100644
index 00000000..328395c8
--- /dev/null
+++ b/alignn/lmdb_dataset.py
@@ -0,0 +1,181 @@
+"""Module to prepare LMDB ALIGNN dataset."""
+
+import os
+import numpy as np
+import lmdb
+from jarvis.core.atoms import Atoms
+from jarvis.db.figshare import data
+from alignn.graphs import Graph
+import pickle as pk
+from torch.utils.data import Dataset
+import torch
+from tqdm import tqdm
+from typing import List, Tuple
+import dgl
+
+
+def prepare_line_graph_batch(
+ batch: Tuple[Tuple[dgl.DGLGraph, dgl.DGLGraph], torch.Tensor],
+ device=None,
+ non_blocking=False,
+):
+ """Send line graph batch to device.
+
+ Note: the batch is a nested tuple, with the graph and line graph together
+ """
+ g, lg, t, id = batch
+ batch = (
+ (
+ g.to(device, non_blocking=non_blocking),
+ lg.to(device, non_blocking=non_blocking),
+ ),
+ t.to(device, non_blocking=non_blocking),
+ )
+
+ return batch
+
+
+class TorchLMDBDataset(Dataset):
+ """Dataset of crystal DGLGraphs using LMDB."""
+
+ def __init__(self, lmdb_path="", ids=[]):
+ """Intitialize with path and ids array."""
+ super(TorchLMDBDataset, self).__init__()
+ self.lmdb_path = lmdb_path
+ self.ids = ids
+ self.env = lmdb.open(self.lmdb_path, readonly=True, lock=False)
+ with self.env.begin() as txn:
+ self.length = txn.stat()["entries"]
+ self.prepare_batch = prepare_line_graph_batch
+
+ def __len__(self):
+ """Get length."""
+ return self.length
+
+ def __getitem__(self, idx):
+ """Get sample."""
+ with self.env.begin() as txn:
+ serialized_data = txn.get(f"{idx}".encode())
+ graph, line_graph, label = pk.loads(serialized_data)
+ return graph, line_graph, label
+
+ def close(self):
+ """Close connection."""
+ self.env.close()
+
+ def __del__(self):
+ """Delete connection."""
+ self.close()
+
+ @staticmethod
+ def collate(samples: List[Tuple[dgl.DGLGraph, torch.Tensor]]):
+        """Dataloader helper to batch graphs across `samples`."""
+ graphs, labels = map(list, zip(*samples))
+ batched_graph = dgl.batch(graphs)
+ return batched_graph, torch.tensor(labels)
+
+ @staticmethod
+ def collate_line_graph(
+ samples: List[Tuple[dgl.DGLGraph, dgl.DGLGraph, torch.Tensor]]
+ ):
+        """Dataloader helper to batch graphs across `samples`."""
+ graphs, line_graphs, labels = map(list, zip(*samples))
+ batched_graph = dgl.batch(graphs)
+ batched_line_graph = dgl.batch(line_graphs)
+ if len(labels[0].size()) > 0:
+ return batched_graph, batched_line_graph, torch.stack(labels)
+ else:
+ return batched_graph, batched_line_graph, torch.tensor(labels)
+
+
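+# A quick usage sketch (comments only; assumes an LMDB file already written
+# by get_torch_dataset below, with illustrative ids):
+#
+#   dset = TorchLMDBDataset(lmdb_path="dataset_train", ids=["JVASP-1", "JVASP-2"])
+#   g, lg, label = dset[0]
+#   batch = TorchLMDBDataset.collate_line_graph([dset[0], dset[1]])
+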
+def get_torch_dataset(
+ dataset=[],
+ id_tag="jid",
+ target="",
+ target_atomwise="",
+ target_grad="",
+ target_stress="",
+ neighbor_strategy="k-nearest",
+ atom_features="cgcnn",
+ use_canonize="",
+ name="",
+ line_graph=True,
+ cutoff=8.0,
+ cutoff_extra=3.0,
+ max_neighbors=12,
+ classification=False,
+ sampler=None,
+ output_dir=".",
+ tmp_name="dataset",
+ map_size=1e12,
+ read_existing=True,
+):
+ """Get Torch Dataset with LMDB."""
+ vals = np.array([ii[target] for ii in dataset]) # df[target].values
+ print("data range", np.max(vals), np.min(vals))
+ f = open(os.path.join(output_dir, tmp_name + "_data_range"), "w")
+ line = "Max=" + str(np.max(vals)) + "\n"
+ f.write(line)
+ line = "Min=" + str(np.min(vals)) + "\n"
+ f.write(line)
+ f.close()
+ ids = []
+ if os.path.exists(tmp_name) and read_existing:
+ for idx, (d) in tqdm(enumerate(dataset), total=len(dataset)):
+ ids.append(d[id_tag])
+ dat = TorchLMDBDataset(lmdb_path=tmp_name, ids=ids)
+ print("Reading dataset", tmp_name)
+ return dat
+ ids = []
+ env = lmdb.open(tmp_name, map_size=int(map_size))
+ with env.begin(write=True) as txn:
+ for idx, (d) in tqdm(enumerate(dataset), total=len(dataset)):
+ ids.append(d[id_tag])
+ g, lg = Graph.atom_dgl_multigraph(
+ Atoms.from_dict(d["atoms"]),
+ cutoff=float(cutoff),
+ max_neighbors=max_neighbors,
+ atom_features=atom_features,
+ compute_line_graph=line_graph,
+ use_canonize=use_canonize,
+ cutoff_extra=cutoff_extra,
+ )
+ label = torch.tensor(d[target]).type(torch.get_default_dtype())
+ # print('label',label,label.view(-1).long())
+ if classification:
+ label = label.long()
+ # label = label.view(-1).long()
+ if "extra_features" in d:
+ natoms = len(d["atoms"]["elements"])
+ g.ndata["extra_features"] = torch.tensor(
+ [d["extra_features"] for n in range(natoms)]
+ ).type(torch.get_default_dtype())
+ if target_atomwise is not None and target_atomwise != "":
+ g.ndata[target_atomwise] = torch.tensor(
+ np.array(d[target_atomwise])
+ ).type(torch.get_default_dtype())
+ if target_grad is not None and target_grad != "":
+ g.ndata[target_grad] = torch.tensor(
+ np.array(d[target_grad])
+ ).type(torch.get_default_dtype())
+ if target_stress is not None and target_stress != "":
+ stress = np.array(d[target_stress])
+ g.ndata[target_stress] = torch.tensor(
+ np.array([stress for ii in range(g.number_of_nodes())])
+ ).type(torch.get_default_dtype())
+
+ # labels.append(label)
+ serialized_data = pk.dumps((g, lg, label))
+ txn.put(f"{idx}".encode(), serialized_data)
+
+ env.close()
+ lmdb_dataset = TorchLMDBDataset(lmdb_path=tmp_name, ids=ids)
+ return lmdb_dataset
+
+
+if __name__ == "__main__":
+ dataset = data("dft_2d")
+ lmdb_dataset = get_torch_dataset(
+ dataset=dataset, target="formation_energy_peratom"
+ )
+ print(lmdb_dataset)
diff --git a/alignn/pretrained.py b/alignn/pretrained.py
index b5ddd9c8..6e49e5b2 100644
--- a/alignn/pretrained.py
+++ b/alignn/pretrained.py
@@ -6,7 +6,6 @@
import zipfile
from tqdm import tqdm
from alignn.models.alignn import ALIGNN, ALIGNNConfig
-from alignn.data import get_torch_dataset
from torch.utils.data import DataLoader
import tempfile
import torch
@@ -17,6 +16,7 @@
from alignn.graphs import Graph
from jarvis.db.jsonutils import dumpjson
import pandas as pd
+from alignn.dataset import get_torch_dataset
# from jarvis.core.graphs import Graph
@@ -340,8 +340,16 @@ def get_multiple_predictions(
model=None,
model_name="jv_formation_energy_peratom_alignn",
print_freq=100,
+ # use_lmdb=True,
):
"""Use pretrained model on a number of structures."""
+ # if use_lmdb:
+ # print("Using LMDB dataset.")
+ # from alignn.lmdb_dataset import get_torch_dataset
+ # else:
+ # print("Not using LMDB dataset, memory footprint maybe high.")
+ # from alignn.dataset import get_torch_dataset
+
# import glob
# atoms_array=[]
# for i in glob.glob("alignn/examples/sample_data/*.vasp"):
diff --git a/alignn/tests/test_prop.py b/alignn/tests/test_prop.py
index aee4b638..bcc34c64 100644
--- a/alignn/tests/test_prop.py
+++ b/alignn/tests/test_prop.py
@@ -12,6 +12,8 @@
from alignn.train_alignn import train_for_folder
from jarvis.db.figshare import get_jid_data
from alignn.ff.ff import AlignnAtomwiseCalculator, default_path, revised_path
+import torch
+from jarvis.db.jsonutils import loadjson, dumpjson
plt.switch_backend("agg")
@@ -62,6 +64,7 @@ def test_models():
config["write_predictions"] = True
config["model"]["name"] = "alignn_atomwise"
+ config["filename"] = "X"
t1 = time.time()
result = train_dgl(config)
t2 = time.time()
@@ -73,6 +76,7 @@ def test_models():
print()
config["model"]["name"] = "alignn_atomwise"
+ config["filename"] = "Y"
config["classification_threshold"] = 0.0
t1 = time.time()
result = train_dgl(config)
@@ -127,7 +131,13 @@ def test_pretrained():
get_multiple_predictions(atoms_array=[Si, Si])
-def test_alignn_train():
+world_size = int(torch.cuda.device_count())
+
+
+def test_alignn_train_regression():
+ # Regression
+ cmd = "rm -rf *train_data *test_data *val_data"
+ os.system(cmd)
root_dir = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../examples/sample_data/")
)
@@ -137,8 +147,18 @@ def test_alignn_train():
"../examples/sample_data/config_example.json",
)
)
- train_for_folder(root_dir=root_dir, config_name=config)
+ tmp = loadjson(config)
+ tmp["filename"] = "AA"
+ dumpjson(data=tmp, filename=config)
+ train_for_folder(
+ rank=0, world_size=world_size, root_dir=root_dir, config_name=config
+ )
+
+def test_alignn_train_regression_multi_out():
+ cmd = "rm -rf *train_data *test_data *val_data"
+ os.system(cmd)
+ # Regression multi-out
root_dir = os.path.abspath(
os.path.join(
os.path.dirname(__file__), "../examples/sample_data_multi_prop/"
@@ -150,8 +170,18 @@ def test_alignn_train():
"../examples/sample_data/config_example.json",
)
)
- train_for_folder(root_dir=root_dir, config_name=config)
+ tmp = loadjson(config)
+ tmp["filename"] = "BB"
+ dumpjson(data=tmp, filename=config)
+ train_for_folder(
+ rank=0, world_size=world_size, root_dir=root_dir, config_name=config
+ )
+
+def test_alignn_train_classification():
+ cmd = "rm -rf *train_data *test_data *val_data"
+ os.system(cmd)
+ # Classification
root_dir = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../examples/sample_data/")
)
@@ -161,10 +191,22 @@ def test_alignn_train():
"../examples/sample_data/config_example.json",
)
)
+ tmp = loadjson(config)
+ tmp["filename"] = "A"
+ dumpjson(data=tmp, filename=config)
train_for_folder(
- root_dir=root_dir, config_name=config, classification_threshold=0.01
+ rank=0,
+ world_size=world_size,
+ root_dir=root_dir,
+ config_name=config,
+ classification_threshold=0.01,
)
+
+def test_alignn_train_ff():
+ cmd = "rm -rf *train_data *test_data *val_data"
+ os.system(cmd)
+ # FF
root_dir = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../examples/sample_data_ff/")
)
@@ -174,7 +216,12 @@ def test_alignn_train():
"../examples/sample_data_ff/config_example_atomwise.json",
)
)
- train_for_folder(root_dir=root_dir, config_name=config)
+ tmp = loadjson(config)
+ tmp["filename"] = "B"
+ dumpjson(data=tmp, filename=config)
+ train_for_folder(
+ rank=0, world_size=world_size, root_dir=root_dir, config_name=config
+ )
def test_calculator():
@@ -234,8 +281,12 @@ def test_del_files():
for i in fnames:
cmd = "rm -r " + i
os.system(cmd)
+    cmd = "rm -r *train_data *val_data *test_data"
+ os.system(cmd)
-
+# test_alignn_train_ff()
+# test_alignn_train_classification()
+# test_alignn_train_regression()
# test_minor_configs()
# test_pretrained()
# test_runtime_training()
diff --git a/alignn/train.py b/alignn/train.py
index 4854c178..dd3675f1 100644
--- a/alignn/train.py
+++ b/alignn/train.py
@@ -5,6 +5,7 @@
then `tensorboard --logdir tb_logs/test` to monitor results...
"""
+from torch.nn.parallel import DistributedDataParallel as DDP
from functools import partial
from typing import Any, Dict, Union
import torch
@@ -29,6 +30,15 @@
torch.set_default_dtype(torch.float32)
+# def setup(rank, world_size):
+# """Set up multi GPU rank."""
+# os.environ["MASTER_ADDR"] = "localhost"
+# os.environ["MASTER_PORT"] = "12355"
+# # Initialize the distributed environment.
+# dist.init_process_group("nccl", rank=rank, world_size=world_size)
+# torch.cuda.set_device(rank)
+
+
def activated_output_transform(output):
"""Exponentiate output."""
y_pred, y = output
@@ -101,6 +111,8 @@ def train_dgl(
model: nn.Module = None,
# checkpoint_dir: Path = Path("./"),
train_val_test_loaders=[],
+ rank=0,
+ world_size=0,
# log_tensorboard: bool = False,
):
"""Training entry point for DGL networks.
@@ -108,20 +120,23 @@ def train_dgl(
`config` should conform to alignn.conf.TrainingConfig, and
if passed as a dict with matching keys, pydantic validation is used
"""
- print(config)
- if type(config) is dict:
- try:
- print(config)
- config = TrainingConfig(**config)
- except Exception as exp:
- print("Check", exp)
+    # print("rank", rank)
+    # setup(rank, world_size)
+    if type(config) is dict:
+        # Validate a plain-dict config on every rank (not only rank 0) so
+        # all processes end up with a TrainingConfig object.
+        try:
+            config = TrainingConfig(**config)
+        except Exception as exp:
+            print("Check", exp)
+    if rank == 0:
+        print("config:")
+        # print(config)
if not os.path.exists(config.output_dir):
os.makedirs(config.output_dir)
# checkpoint_dir = os.path.join(config.output_dir)
# deterministic = False
classification = False
- print("config:")
tmp = config.dict()
f = open(os.path.join(config.output_dir, "config.json"), "w")
f.write(json.dumps(tmp, indent=4))
@@ -135,6 +150,13 @@ def train_dgl(
line_graph = False
if config.model.alignn_layers > 0:
line_graph = True
+ if world_size > 1:
+ use_ddp = True
+ else:
+ use_ddp = False
+ device = "cpu"
+ if torch.cuda.is_available():
+ device = torch.device("cuda")
if not train_val_test_loaders:
# use input standardization for all real-valued feature sets
# print("config.neighbor_strategy",config.neighbor_strategy)
@@ -173,15 +195,16 @@ def train_dgl(
standard_scalar_and_pca=config.standard_scalar_and_pca,
keep_data_order=config.keep_data_order,
output_dir=config.output_dir,
+ use_lmdb=config.use_lmdb,
)
else:
train_loader = train_val_test_loaders[0]
val_loader = train_val_test_loaders[1]
test_loader = train_val_test_loaders[2]
prepare_batch = train_val_test_loaders[3]
- device = "cpu"
- if torch.cuda.is_available():
- device = torch.device("cuda")
+ # rank=0
+ if use_ddp:
+ device = torch.device(f"cuda:{rank}")
prepare_batch = partial(prepare_batch, device=device)
if classification:
config.model.classification = True
@@ -208,12 +231,12 @@ def train_dgl(
net = _model.get(config.model.name)(config.model)
else:
net = model
+
+ # print("net", net)
+ # print("device", device)
net.to(device)
- if config.data_parallel and torch.cuda.device_count() > 1:
- # For multi-GPU training make data_parallel:true in config.json file
- device_ids = [cid for cid in range(torch.cuda.device_count())]
- print("Let's use", torch.cuda.device_count(), "GPUs!")
- net = torch.nn.DataParallel(net, device_ids=device_ids).cuda()
+ if use_ddp:
+ net = DDP(net, device_ids=[rank], find_unused_parameters=True)
# group parameters to skip weight decay for bias and batchnorm
params = group_decay(net)
optimizer = setup_optimizer(params, config)
@@ -452,23 +475,7 @@ def get_batch_errors(dat=[]):
scheduler.step()
train_final_time = time.time()
train_ep_time = train_final_time - train_init_time
- print(
- "TrainLoss",
- "Epoch",
- e,
- "total",
- running_loss,
- "out",
- mean_out,
- "atom",
- mean_atom,
- "grad",
- mean_grad,
- "stress",
- mean_stress,
- "time",
- train_ep_time,
- )
+ # if rank == 0: # or world_size == 1:
history_train.append([mean_out, mean_atom, mean_grad, mean_stress])
dumpjson(
filename=os.path.join(config.output_dir, "history_train.json"),
@@ -587,260 +594,293 @@ def get_batch_errors(dat=[]):
data=val_result,
)
best_model = net
- print(
- "ValLoss",
- "Epoch",
- e,
- "total",
- val_loss,
- "out",
- mean_out,
- "atom",
- mean_atom,
- "grad",
- mean_grad,
- "stress",
- mean_stress,
- saving_msg,
- )
history_val.append([mean_out, mean_atom, mean_grad, mean_stress])
dumpjson(
filename=os.path.join(config.output_dir, "history_val.json"),
data=history_val,
)
-
- test_loss = 0
- test_result = []
- for dats, jid in zip(test_loader, test_loader.dataset.ids):
- # for dats in test_loader:
- info = {}
- info["id"] = jid
- optimizer.zero_grad()
- # print('dats[0]',dats[0])
- # print('test_loader',test_loader)
- # print('test_loader.dataset.ids',test_loader.dataset.ids)
- result = net([dats[0].to(device), dats[1].to(device)])
- loss1 = 0 # Such as energy
- loss2 = 0 # Such as bader charges
- loss3 = 0 # Such as forces
- loss4 = 0 # Such as stresses
- if config.model.output_features is not None and not classification:
- # print('result["out"]',result["out"])
- # print('dats[2]',dats[2])
- loss1 = config.model.graphwise_weight * criterion(
- result["out"], dats[2].to(device)
+ # print('rank',rank)
+ # print('world_size',world_size)
+ if rank == 0:
+ print(
+ "TrainLoss",
+ "Epoch",
+ e,
+ "total",
+ running_loss,
+ "out",
+ mean_out,
+ "atom",
+ mean_atom,
+ "grad",
+ mean_grad,
+ "stress",
+ mean_stress,
+ "time",
+ train_ep_time,
)
- info["target_out"] = dats[2].cpu().numpy().tolist()
- info["pred_out"] = (
- result["out"].cpu().detach().numpy().tolist()
+ print(
+ "ValLoss",
+ "Epoch",
+ e,
+ "total",
+ val_loss,
+ "out",
+ mean_out,
+ "atom",
+ mean_atom,
+ "grad",
+ mean_grad,
+ "stress",
+ mean_stress,
+ saving_msg,
)
- if config.model.atomwise_output_features > 0:
- loss2 = config.model.atomwise_weight * criterion(
- result["atomwise_pred"].to(device),
- dats[0].ndata["atomwise_target"].to(device),
- )
- info["target_atomwise_pred"] = (
- dats[0].ndata["atomwise_target"].cpu().numpy().tolist()
- )
- info["pred_atomwise_pred"] = (
- result["atomwise_pred"].cpu().detach().numpy().tolist()
- )
+ if rank == 0 or world_size == 1:
+ test_loss = 0
+ test_result = []
+ for dats, jid in zip(test_loader, test_loader.dataset.ids):
+ # for dats in test_loader:
+ info = {}
+ info["id"] = jid
+ optimizer.zero_grad()
+ # print('dats[0]',dats[0])
+ # print('test_loader',test_loader)
+ # print('test_loader.dataset.ids',test_loader.dataset.ids)
+ result = net([dats[0].to(device), dats[1].to(device)])
+ loss1 = 0 # Such as energy
+ loss2 = 0 # Such as bader charges
+ loss3 = 0 # Such as forces
+ loss4 = 0 # Such as stresses
+ if (
+ config.model.output_features is not None
+ and not classification
+ ):
+ # print('result["out"]',result["out"])
+ # print('dats[2]',dats[2])
+ loss1 = config.model.graphwise_weight * criterion(
+ result["out"], dats[2].to(device)
+ )
+ info["target_out"] = dats[2].cpu().numpy().tolist()
+ info["pred_out"] = (
+ result["out"].cpu().detach().numpy().tolist()
+ )
- if config.model.calculate_gradient:
- loss3 = config.model.gradwise_weight * criterion(
- result["grad"].to(device),
- dats[0].ndata["atomwise_grad"].to(device),
- )
- info["target_grad"] = (
- dats[0].ndata["atomwise_grad"].cpu().numpy().tolist()
- )
- info["pred_grad"] = (
- result["grad"].cpu().detach().numpy().tolist()
- )
- if config.model.stresswise_weight != 0:
- loss4 = config.model.stresswise_weight * criterion(
- # torch.flatten(result["stress"].to(device)),
- # (dats[0].ndata["stresses"]).to(device),
- # torch.flatten(dats[0].ndata["stresses"]).to(device),
- result["stresses"].to(device),
- torch.cat(tuple(dats[0].ndata["stresses"])).to(device),
- # torch.flatten(torch.cat(dats[0].ndata["stresses"])).to(device),
- # dats[0].ndata["stresses"][0].to(device),
- )
- # loss4 = config.model.stresswise_weight * criterion(
- # result["stress"][0].to(device),
- # dats[0].ndata["stresses"].to(device),
- # )
- info["target_stress"] = (
- torch.cat(tuple(dats[0].ndata["stresses"]))
- .cpu()
- .numpy()
- .tolist()
- )
- info["pred_stress"] = (
- result["stresses"].cpu().detach().numpy().tolist()
- )
- test_result.append(info)
- loss = loss1 + loss2 + loss3 + loss4
- if not classification:
- test_loss += loss.item()
- print("TestLoss", e, test_loss)
- dumpjson(
- filename=os.path.join(config.output_dir, "Test_results.json"),
- data=test_result,
- )
- last_model_name = "last_model.pt"
- torch.save(
- net.state_dict(),
- os.path.join(config.output_dir, last_model_name),
- )
- # return test_result
-
- if config.write_predictions and classification:
- best_model.eval()
- # net.eval()
- f = open(
- os.path.join(config.output_dir, "prediction_results_test_set.csv"),
- "w",
- )
- f.write("id,target,prediction\n")
- targets = []
- predictions = []
- with torch.no_grad():
- ids = test_loader.dataset.ids # [test_loader.dataset.indices]
- for dat, id in zip(test_loader, ids):
- g, lg, target = dat
- out_data = best_model([g.to(device), lg.to(device)])["out"]
- # out_data = net([g.to(device), lg.to(device)])["out"]
- # out_data = torch.exp(out_data.cpu())
- # print('target',target)
- # print('out_data',out_data)
- top_p, top_class = torch.topk(torch.exp(out_data), k=1)
- target = int(target.cpu().numpy().flatten().tolist()[0])
-
- f.write("%s, %d, %d\n" % (id, (target), (top_class)))
- targets.append(target)
- predictions.append(
- top_class.cpu().numpy().flatten().tolist()[0]
- )
- f.close()
+ if config.model.atomwise_output_features > 0:
+ loss2 = config.model.atomwise_weight * criterion(
+ result["atomwise_pred"].to(device),
+ dats[0].ndata["atomwise_target"].to(device),
+ )
+ info["target_atomwise_pred"] = (
+ dats[0].ndata["atomwise_target"].cpu().numpy().tolist()
+ )
+ info["pred_atomwise_pred"] = (
+ result["atomwise_pred"].cpu().detach().numpy().tolist()
+ )
- print("predictions", predictions)
- print("targets", targets)
- print(
- "Test ROCAUC:",
- roc_auc_score(np.array(targets), np.array(predictions)),
- )
+ if config.model.calculate_gradient:
+ loss3 = config.model.gradwise_weight * criterion(
+ result["grad"].to(device),
+ dats[0].ndata["atomwise_grad"].to(device),
+ )
+ info["target_grad"] = (
+ dats[0].ndata["atomwise_grad"].cpu().numpy().tolist()
+ )
+ info["pred_grad"] = (
+ result["grad"].cpu().detach().numpy().tolist()
+ )
+ if config.model.stresswise_weight != 0:
+ loss4 = config.model.stresswise_weight * criterion(
+ # torch.flatten(result["stress"].to(device)),
+ # (dats[0].ndata["stresses"]).to(device),
+ # torch.flatten(dats[0].ndata["stresses"]).to(device),
+ result["stresses"].to(device),
+ torch.cat(tuple(dats[0].ndata["stresses"])).to(device),
+ # torch.flatten(torch.cat(dats[0].ndata["stresses"])).to(device),
+ # dats[0].ndata["stresses"][0].to(device),
+ )
+ # loss4 = config.model.stresswise_weight * criterion(
+ # result["stress"][0].to(device),
+ # dats[0].ndata["stresses"].to(device),
+ # )
+ info["target_stress"] = (
+ torch.cat(tuple(dats[0].ndata["stresses"]))
+ .cpu()
+ .numpy()
+ .tolist()
+ )
+ info["pred_stress"] = (
+ result["stresses"].cpu().detach().numpy().tolist()
+ )
+ test_result.append(info)
+ loss = loss1 + loss2 + loss3 + loss4
+ if not classification:
+ test_loss += loss.item()
+ print("TestLoss", e, test_loss)
+ dumpjson(
+ filename=os.path.join(config.output_dir, "Test_results.json"),
+ data=test_result,
+ )
+ last_model_name = "last_model.pt"
+ torch.save(
+ net.state_dict(),
+ os.path.join(config.output_dir, last_model_name),
+ )
+ # return test_result
+ if rank == 0 or world_size == 1:
+ if config.write_predictions and classification:
+ best_model.eval()
+ # net.eval()
+ f = open(
+ os.path.join(
+ config.output_dir, "prediction_results_test_set.csv"
+ ),
+ "w",
+ )
+ f.write("id,target,prediction\n")
+ targets = []
+ predictions = []
+ with torch.no_grad():
+ ids = test_loader.dataset.ids # [test_loader.dataset.indices]
+ for dat, id in zip(test_loader, ids):
+ g, lg, target = dat
+ out_data = best_model([g.to(device), lg.to(device)])["out"]
+ # out_data = net([g.to(device), lg.to(device)])["out"]
+ # out_data = torch.exp(out_data.cpu())
+ # print('target',target)
+ # print('out_data',out_data)
+ top_p, top_class = torch.topk(torch.exp(out_data), k=1)
+ target = int(target.cpu().numpy().flatten().tolist()[0])
+
+ f.write("%s, %d, %d\n" % (id, (target), (top_class)))
+ targets.append(target)
+ predictions.append(
+ top_class.cpu().numpy().flatten().tolist()[0]
+ )
+ f.close()
- if (
- config.write_predictions
- and not classification
- and config.model.output_features > 1
- ):
- best_model.eval()
- # net.eval()
- mem = []
- with torch.no_grad():
- ids = test_loader.dataset.ids # [test_loader.dataset.indices]
- for dat, id in zip(test_loader, ids):
- g, lg, target = dat
- out_data = best_model([g.to(device), lg.to(device)])["out"]
- # out_data = net([g.to(device), lg.to(device)])["out"]
- out_data = out_data.cpu().numpy().tolist()
- if config.standard_scalar_and_pca:
- sc = pk.load(open("sc.pkl", "rb"))
- out_data = list(
- sc.transform(np.array(out_data).reshape(1, -1))[0]
- ) # [0][0]
- target = target.cpu().numpy().flatten().tolist()
- info = {}
- info["id"] = id
- info["target"] = target
- info["predictions"] = out_data
- mem.append(info)
- dumpjson(
- filename=os.path.join(
- config.output_dir, "multi_out_predictions.json"
- ),
- data=mem,
- )
- if (
- config.write_predictions
- and not classification
- and config.model.output_features == 1
- and config.model.gradwise_weight == 0
- ):
- best_model.eval()
- # net.eval()
- f = open(
- os.path.join(config.output_dir, "prediction_results_test_set.csv"),
- "w",
- )
- f.write("id,target,prediction\n")
- targets = []
- predictions = []
- with torch.no_grad():
- ids = test_loader.dataset.ids # [test_loader.dataset.indices]
- for dat, id in zip(test_loader, ids):
- g, lg, target = dat
- out_data = best_model([g.to(device), lg.to(device)])["out"]
- # out_data = net([g.to(device), lg.to(device)])["out"]
- out_data = out_data.cpu().numpy().tolist()
- if config.standard_scalar_and_pca:
- sc = pk.load(
- open(os.path.join(tmp_output_dir, "sc.pkl"), "rb")
- )
- out_data = sc.transform(np.array(out_data).reshape(-1, 1))[
- 0
- ][0]
- target = target.cpu().numpy().flatten().tolist()
- if len(target) == 1:
- target = target[0]
- f.write("%s, %6f, %6f\n" % (id, target, out_data))
- targets.append(target)
- predictions.append(out_data)
- f.close()
-
- print(
- "Test MAE:",
- mean_absolute_error(np.array(targets), np.array(predictions)),
- )
- best_model.eval()
- # net.eval()
- f = open(
- os.path.join(
- config.output_dir, "prediction_results_train_set.csv"
- ),
- "w",
- )
- f.write("target,prediction\n")
- targets = []
- predictions = []
- with torch.no_grad():
- ids = train_loader.dataset.ids # [test_loader.dataset.indices]
- for dat, id in zip(train_loader, ids):
- g, lg, target = dat
- out_data = best_model([g.to(device), lg.to(device)])["out"]
- # out_data = net([g.to(device), lg.to(device)])["out"]
- out_data = out_data.cpu().numpy().tolist()
- if config.standard_scalar_and_pca:
- sc = pk.load(
- open(os.path.join(tmp_output_dir, "sc.pkl"), "rb")
- )
- out_data = sc.transform(np.array(out_data).reshape(-1, 1))[
- 0
- ][0]
- target = target.cpu().numpy().flatten().tolist()
- # if len(target) == 1:
- # target = target[0]
- # if len(out_data) == 1:
- # out_data = out_data[0]
- for ii, jj in zip(target, out_data):
- f.write("%6f, %6f\n" % (ii, jj))
- targets.append(ii)
- predictions.append(jj)
- f.close()
+ print("predictions", predictions)
+ print("targets", targets)
+ print(
+ "Test ROCAUC:",
+ roc_auc_score(np.array(targets), np.array(predictions)),
+ )
+
+ if (
+ config.write_predictions
+ and not classification
+ and config.model.output_features > 1
+ ):
+ best_model.eval()
+ # net.eval()
+ mem = []
+ with torch.no_grad():
+ ids = test_loader.dataset.ids # [test_loader.dataset.indices]
+ for dat, id in zip(test_loader, ids):
+ g, lg, target = dat
+ out_data = best_model([g.to(device), lg.to(device)])["out"]
+ # out_data = net([g.to(device), lg.to(device)])["out"]
+ out_data = out_data.cpu().numpy().tolist()
+ if config.standard_scalar_and_pca:
+ sc = pk.load(open("sc.pkl", "rb"))
+ out_data = list(
+ sc.transform(np.array(out_data).reshape(1, -1))[0]
+ ) # [0][0]
+ target = target.cpu().numpy().flatten().tolist()
+ info = {}
+ info["id"] = id
+ info["target"] = target
+ info["predictions"] = out_data
+ mem.append(info)
+ dumpjson(
+ filename=os.path.join(
+ config.output_dir, "multi_out_predictions.json"
+ ),
+ data=mem,
+ )
+ if (
+ config.write_predictions
+ and not classification
+ and config.model.output_features == 1
+ and config.model.gradwise_weight == 0
+ ):
+ best_model.eval()
+ # net.eval()
+ f = open(
+ os.path.join(
+ config.output_dir, "prediction_results_test_set.csv"
+ ),
+ "w",
+ )
+ f.write("id,target,prediction\n")
+ targets = []
+ predictions = []
+ with torch.no_grad():
+ ids = test_loader.dataset.ids # [test_loader.dataset.indices]
+ for dat, id in zip(test_loader, ids):
+ g, lg, target = dat
+ out_data = best_model([g.to(device), lg.to(device)])["out"]
+ # out_data = net([g.to(device), lg.to(device)])["out"]
+ out_data = out_data.cpu().numpy().tolist()
+ if config.standard_scalar_and_pca:
+ sc = pk.load(
+ open(os.path.join(tmp_output_dir, "sc.pkl"), "rb")
+ )
+ out_data = sc.transform(
+ np.array(out_data).reshape(-1, 1)
+ )[0][0]
+ target = target.cpu().numpy().flatten().tolist()
+ if len(target) == 1:
+ target = target[0]
+ f.write("%s, %6f, %6f\n" % (id, target, out_data))
+ targets.append(target)
+ predictions.append(out_data)
+ f.close()
+
+ print(
+ "Test MAE:",
+ mean_absolute_error(np.array(targets), np.array(predictions)),
+ )
+ best_model.eval()
+ # net.eval()
+ f = open(
+ os.path.join(
+ config.output_dir, "prediction_results_train_set.csv"
+ ),
+ "w",
+ )
+ f.write("target,prediction\n")
+ targets = []
+ predictions = []
+ with torch.no_grad():
+ ids = train_loader.dataset.ids # [test_loader.dataset.indices]
+ for dat, id in zip(train_loader, ids):
+ g, lg, target = dat
+ out_data = best_model([g.to(device), lg.to(device)])["out"]
+ # out_data = net([g.to(device), lg.to(device)])["out"]
+ out_data = out_data.cpu().numpy().tolist()
+ if config.standard_scalar_and_pca:
+ sc = pk.load(
+ open(os.path.join(tmp_output_dir, "sc.pkl"), "rb")
+ )
+ out_data = sc.transform(
+ np.array(out_data).reshape(-1, 1)
+ )[0][0]
+ target = target.cpu().numpy().flatten().tolist()
+ # if len(target) == 1:
+ # target = target[0]
+ # if len(out_data) == 1:
+ # out_data = out_data[0]
+ for ii, jj in zip(target, out_data):
+ f.write("%6f, %6f\n" % (ii, jj))
+ targets.append(ii)
+ predictions.append(jj)
+ f.close()
+ if config.use_lmdb:
+ print("Closing LMDB.")
+ train_loader.dataset.close()
+ val_loader.dataset.close()
+ test_loader.dataset.close()
if __name__ == "__main__":
diff --git a/alignn/train_alignn.py b/alignn/train_alignn.py
index 79f78961..2e058951 100644
--- a/alignn/train_alignn.py
+++ b/alignn/train_alignn.py
@@ -2,6 +2,7 @@
"""Module to train for a folder with formatted dataset."""
import os
+import torch.distributed as dist
import csv
import sys
import json
@@ -15,12 +16,33 @@
import torch
import time
from jarvis.core.atoms import Atoms
+import random
device = "cpu"
if torch.cuda.is_available():
device = torch.device("cuda")
+def setup(rank=0, world_size=0, port="12356"):
+    """Set up the distributed process group for multi-GPU training."""
+    if port == "":
+        # Fall back to a random port if none is given.
+        port = str(random.randint(10000, 99999))
+    if world_size > 1:
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = port
+        # Initialize the distributed environment.
+        dist.init_process_group("nccl", rank=rank, world_size=world_size)
+        torch.cuda.set_device(rank)
+
+
+def cleanup(world_size):
+ """Clean up distributed process."""
+ if world_size > 1:
+ dist.destroy_process_group()
+
+
parser = argparse.ArgumentParser(
description="Atomistic Line Graph Neural Network"
)
@@ -115,9 +137,10 @@
def train_for_folder(
+ rank=0,
+ world_size=0,
root_dir="examples/sample_data",
config_name="config.json",
- # keep_data_order=False,
classification_threshold=None,
batch_size=None,
epochs=None,
@@ -128,11 +151,11 @@ def train_for_folder(
stresswise_key="stresses",
file_format="poscar",
restart_model_path=None,
- # subtract_mean=False,
- # normalize_with_natoms=False,
output_dir=None,
):
"""Train for a folder."""
+ setup(rank=rank, world_size=world_size)
+ print("root_dir", root_dir)
id_prop_json = os.path.join(root_dir, "id_prop.json")
id_prop_json_zip = os.path.join(root_dir, "id_prop.json.zip")
id_prop_csv = os.path.join(root_dir, "id_prop.csv")
@@ -369,9 +392,13 @@ def train_for_folder(
standard_scalar_and_pca=config.standard_scalar_and_pca,
keep_data_order=config.keep_data_order,
output_dir=config.output_dir,
+ use_lmdb=config.use_lmdb,
)
# print("dataset", dataset[0])
t1 = time.time()
+ # world_size = torch.cuda.device_count()
+ print("rank", rank)
+ print("world_size", world_size)
train_dgl(
config,
model=model,
@@ -381,6 +408,8 @@ def train_for_folder(
test_loader,
prepare_batch,
],
+ rank=rank,
+ world_size=world_size,
)
t2 = time.time()
print("Time taken (s)", t2 - t1)
@@ -390,21 +419,48 @@ def train_for_folder(
if __name__ == "__main__":
args = parser.parse_args(sys.argv[1:])
- train_for_folder(
- root_dir=args.root_dir,
- config_name=args.config_name,
- # keep_data_order=args.keep_data_order,
- classification_threshold=args.classification_threshold,
- output_dir=args.output_dir,
- batch_size=(args.batch_size),
- epochs=(args.epochs),
- target_key=(args.target_key),
- id_key=(args.id_key),
- atomwise_key=(args.atomwise_key),
- gradwise_key=(args.force_key),
- stresswise_key=(args.stresswise_key),
- restart_model_path=(args.restart_model_path),
- # subtract_mean=(args.subtract_mean),
- # normalize_with_natoms=(args.normalize_with_natoms),
- file_format=(args.file_format),
- )
+ world_size = int(torch.cuda.device_count())
+ print("world_size", world_size)
+ if world_size > 1:
+ torch.multiprocessing.spawn(
+ train_for_folder,
+ args=(
+ world_size,
+ args.root_dir,
+ args.config_name,
+ args.classification_threshold,
+ args.batch_size,
+ args.epochs,
+ args.id_key,
+ args.target_key,
+ args.atomwise_key,
+ args.force_key,
+ args.stresswise_key,
+ args.file_format,
+ args.restart_model_path,
+ args.output_dir,
+ ),
+ nprocs=world_size,
+ )
+ else:
+ train_for_folder(
+ 0,
+ world_size,
+ args.root_dir,
+ args.config_name,
+ args.classification_threshold,
+ args.batch_size,
+ args.epochs,
+ args.id_key,
+ args.target_key,
+ args.atomwise_key,
+ args.force_key,
+ args.stresswise_key,
+ args.file_format,
+ args.restart_model_path,
+ args.output_dir,
+ )
+ try:
+ cleanup(world_size)
+ except Exception:
+ pass
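+
+# Example launch (a sketch; --root_dir and --config_name are the argparse
+# flags defined above; DDP via torch.multiprocessing.spawn kicks in
+# automatically when torch.cuda.device_count() > 1):
+#
+#   python alignn/train_alignn.py \
+#       --root_dir alignn/examples/sample_data \
+#       --config_name alignn/examples/sample_data/config_example.json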
diff --git a/environment.yml b/environment.yml
index c7460d58..886addb5 100644
--- a/environment.yml
+++ b/environment.yml
@@ -168,6 +168,7 @@ dependencies:
- pysocks=1.7.1=pyha2e5f31_6
- pytest=8.1.1=pyhd8ed1ab_0
- python=3.10.13=hd12c33a_0_cpython
+ - python-lmdb=1.4.1=py310hdf73078_1
- python-tzdata=2024.1=pyhd8ed1ab_0
- python_abi=3.10=4_cp310
- pytz=2024.1=pyhd8ed1ab_0
diff --git a/setup.py b/setup.py
index e075929c..c3f078f3 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@
setuptools.setup(
name="alignn",
- version="2024.4.10",
+ version="2024.4.20",
author="Kamal Choudhary, Brian DeCost",
author_email="kamal.choudhary@nist.gov",
description="alignn",
@@ -33,6 +33,7 @@
"pydocstyle>=6.0.0",
"pyparsing>=2.2.1,<3",
"ase",
+ "lmdb",
# "pytorch-ignite>=0.5.0.dev20221024",
# "accelerate>=0.20.3",
# "dgl-cu101>=0.6.0",