<div align="center">

# Grad-DFT: a software library for machine learning enhanced density functional theory

[![build](https://img.shields.io/badge/build-passing-graygreen.svg "https://github.com/XanaduAI/GradDFT/actions")](https://github.com/XanaduAI/GradDFT/actions) ![arXiv](http://img.shields.io/badge/arXiv-2101.10279-B31B1B.svg "Grad-DFT") ![License](https://img.shields.io/badge/License-Apache%202.0-9F9F9F "https://github.com/XanaduAI/GradDFT/blob/main/LICENSE")

</div>

To use Grad-DFT we first need a molecule. Assuming `mf` is a PySCF mean-field object (e.g. `dft.UKS`) for the system of interest, we run the self-consistent calculation and convert the result into a Grad-DFT `Molecule`:

```python
from grad_dft.interface import molecule_from_pyscf

mf.kernel()
molecule = molecule_from_pyscf(mf)
```

### Creating a simple functional

Then we can create a `Functional`.
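In Grad-DFT, a functional combines an array of energy densities $\mathbf{e}[\rho](\mathbf{r})$ with an array of (possibly trainable) coefficient functions $\mathbf{c}_\theta[\rho](\mathbf{r})$, schematically

$$E_{xc,\theta}[\rho] = \int d\mathbf{r}\; \mathbf{c}_\theta[\rho](\mathbf{r}) \cdot \mathbf{e}[\rho](\mathbf{r}).$$

For the LDA exchange functional below there is a single energy density and a single constant coefficient equal to 1.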

```python
from jax import numpy as jnp
from grad_dft.functional import Functional

def energy_densities(molecule):
    # Spin-polarized LDA exchange energy density, summed over spin channels
    rho = molecule.density()
    lda_e = -3/2 * (3/(4*jnp.pi))**(1/3) * (rho**(4/3)).sum(axis=0, keepdims=True)
    return lda_e

# LDA exchange uses a single, fixed coefficient of 1
coefficients = lambda self, rho: jnp.array([[1.]])

LDA = Functional(coefficients, energy_densities)
```

We can use the functional to compute the predicted energy, where `params` stands for the $\theta$ parameters in the equation above.

```python
from flax.core import freeze

# LDA has no trainable parameters, so the parameter dictionary is empty
params = freeze({'params': {}})
predicted_energy = LDA.energy(params, molecule)
```
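Because Grad-DFT is built on JAX, this predicted energy is differentiable with respect to the functional parameters, which is what makes the training shown below possible.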

### Creating a neural functional

A more complex, neural-network-based functional can be created as follows:

```python
from jax.nn import sigmoid, gelu
from flax import linen as nn
from optax import adam, apply_updates
from tqdm import tqdm
from grad_dft.train import molecule_predictor
from grad_dft.functional import NeuralFunctional, default_loss
from grad_dft.interface import molecule_from_pyscf

def coefficient_inputs(molecule):
    rho = jnp.clip(molecule.density(), a_min=1e-30)
    # ... (the remaining coefficient inputs, the neural `coefficients`
    # network and the `energy_densities` definitions are elided here)

neuralfunctional = NeuralFunctional(coefficients, energy_densities, coefficient_inputs)
```
with the corresponding energy calculation

```python
from jax.random import PRNGKey

key = PRNGKey(42)
cinputs = coefficient_inputs(molecule)
params = neuralfunctional.init(key, cinputs)

predicted_energy = neuralfunctional.energy(params, molecule)
```
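Here `init` follows the usual Flax `Module.init` pattern: it takes a PRNG key and example coefficient inputs, and returns the initial parameter pytree that `energy` consumes.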

### Training the neural functional
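The sketch below assumes that a training molecule `HH_molecule` and its reference energy `ground_truth_energy` have been defined beforehand; the parameters are then fitted with `optax`: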

```python
# Define the training hyperparameters
learning_rate = 1e-5
momentum = 0.9
tx = adam(learning_rate=learning_rate, b1=momentum)
opt_state = tx.init(params)

# ... and implement the optimization loop
n_epochs = 20
molecule_predict = molecule_predictor(neuralfunctional)
for iteration in tqdm(range(n_epochs), desc="Training epoch"):
    (cost_value, predicted_energy), grads = default_loss(
        params, molecule_predict, HH_molecule, ground_truth_energy
    )
    print("Iteration", iteration, "Predicted energy:", predicted_energy, "Cost value:", cost_value)
    updates, opt_state = tx.update(grads, opt_state, params)
    params = apply_updates(params, updates)

# Save a checkpoint of the trained parameters
neuralfunctional.save_checkpoints(params, tx, step=n_epochs)
```
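After training, the updated `params` can be used exactly as before; for instance, as a minimal check of the trained functional:

```python
# Re-evaluate the trained functional on a molecule (same API as before training)
final_energy = neuralfunctional.energy(params, molecule)
print("Energy after training:", final_energy)
```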

## Install

A core dependency of Grad-DFT is [PySCF](https://pyscf.org). To install this package successfully with `pip` in the next step, please ensure that `cmake` is installed. Then run

```
pip install -e ".[examples]"
```

to install the additional dependencies.
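If needed, you can check that `cmake` is available with

```
cmake --version
```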

## Acknowledgements

We thank Alain Delgado, Modjtaba Shokrian Zini, Stepan Fomichev, Soran Jahangiri, Diego Guala, Jay Soni, Utkarsh Azad, Vincent Michaud-Rioux, Maria Schuld and Nathan Wiebe for their helpful comments and insights.

Grad-DFT often follows calculations and naming conventions similar to PySCF's, adapted for our purposes. Only a few non-jittable DIIS procedures were taken directly from PySCF; wherever this happens, it is referenced in the documentation. The tests were also validated against PySCF results.

## Bibtex

```
```
