-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #60 from sunqm/trexio-itrf
Trexio interface
- Loading branch information
Showing
3 changed files
with
257 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import pyscf | ||
from pyscf.tools import trexio | ||
|
||
def test_mol(): | ||
filename = 'test.h5' | ||
mol = pyscf.M(atom='H 0 0 0; F 0 0 1', basis='6-31g**') | ||
ref = mol.intor('int1e_ovlp') | ||
trexio.to_trexio(mol, filename) | ||
mol = trexio.mol_from_trexio(filename) | ||
s1 = mol.intor('int1e_ovlp') | ||
assert abs(ref - s1).max() < 1e-12 | ||
|
||
mol = pyscf.M(atom='H 0 0 0; F 0 0 1', basis='6-31g**', cart=True) | ||
ref = mol.intor('int1e_ovlp') | ||
trexio.to_trexio(mol, filename) | ||
mol = trexio.mol_from_trexio(filename) | ||
s1 = mol.intor('int1e_ovlp') | ||
assert abs(ref - s1).max() < 1e-12 | ||
|
||
def test_mf(): | ||
filename = 'test1.h5' | ||
mol = pyscf.M(atom='H 0 0 0; F 0 0 1', basis='6-31g*') | ||
mf = mol.RHF().run() | ||
trexio.to_trexio(mf, filename) | ||
mf1 = trexio.scf_from_trexio(filename) | ||
assert abs(mf1.mo_coeff - mf.mo_coeff).max() < 1e-12 | ||
|
||
trexio.write_eri(mf._eri, filename) | ||
eri = trexio.read_eri(filename) | ||
assert abs(mf._eri - eri).max() < 1e-12 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,225 @@ | ||
#!/usr/bin/env python | ||
|
||
''' | ||
TREX-IO interface | ||
References: | ||
https://trex-coe.github.io/trexio/trex.html | ||
https://github.com/TREX-CoE/trexio-tutorials/blob/ad5c60aa6a7bca802c5918cef2aeb4debfa9f134/notebooks/tutorial_benzene.md | ||
Installation instruction: | ||
https://github.com/TREX-CoE/trexio/blob/master/python/README.md | ||
''' | ||
|
||
import numpy as np | ||
from pyscf import lib | ||
from pyscf import gto | ||
from pyscf import scf | ||
import trexio | ||
|
||
def to_trexio(obj, filename, backend='h5'): | ||
with trexio.File(filename, 'u', back_end=_mode(backend)) as tf: | ||
if isinstance(obj, gto.Mole): | ||
_mol_to_trexio(obj, tf) | ||
elif isinstance(obj, scf.hf.SCF): | ||
_scf_to_trexio(obj, tf) | ||
else: | ||
raise NotImplementedError(f'Conversion function for {obj.__class__}') | ||
|
||
def _mol_to_trexio(mol, trexio_file): | ||
trexio.write_nucleus_num(trexio_file, mol.natm) | ||
|
||
labels = [mol.atom_pure_symbol(i) for i in range(mol.natm)] | ||
trexio.write_nucleus_label(trexio_file, labels) | ||
if mol._ecp: | ||
raise NotImplementedError | ||
trexio.write_nucleus_charge(trexio_file, mol.atom_charges()) | ||
trexio.write_nucleus_coord(trexio_file, mol.atom_coords()) | ||
|
||
trexio.write_basis_type(trexio_file, 'Gaussian') | ||
trexio.write_basis_shell_num(trexio_file, mol.nbas) | ||
trexio.write_basis_prim_num(trexio_file, int(mol._bas[:,gto.NPRIM_OF].sum())) | ||
trexio.write_basis_nucleus_index(trexio_file, mol._bas[:,gto.ATOM_OF]) | ||
trexio.write_basis_shell_ang_mom(trexio_file, mol._bas[:,gto.ANG_OF]) | ||
trexio.write_basis_shell_factor(trexio_file, np.ones(mol.nbas)) | ||
prim2sh = [[ib]*nprim for ib, nprim in enumerate(mol._bas[:,gto.NPRIM_OF])] | ||
trexio.write_basis_shell_index(trexio_file, np.hstack(prim2sh)) | ||
trexio.write_basis_exponent(trexio_file, np.hstack(mol.bas_exps())) | ||
assert all(mol._bas[:,gto.NCTR_OF] == 1) | ||
coef = [mol.bas_ctr_coeff(i).ravel() for i in range(mol.nbas)] | ||
trexio.write_basis_coefficient(trexio_file, np.hstack(coef)) | ||
prim_norms = [gto.gto_norm(mol.bas_angular(i), mol.bas_exp(i)) | ||
for i in range(mol.nbas)] | ||
trexio.write_basis_prim_factor(trexio_file, np.hstack(prim_norms)) | ||
|
||
if mol.symmetry: | ||
trexio.write_nucleus_point_group(trexio_file, mol.groupname) | ||
|
||
if mol.cart: | ||
trexio.write_ao_cartesian(trexio_file, 1) | ||
else: | ||
trexio.write_ao_cartesian(trexio_file, 0) | ||
|
||
def _scf_to_trexio(mf, trexio_file): | ||
mol = mf.mol | ||
_mol_to_trexio(mol, trexio_file) | ||
if isinstance(mf, scf.uhf.UHF): | ||
mo_energy = np.ravel(mf.mo_energy) | ||
mo = np.hstack(*mf.mo_coeff) | ||
mo_occ = np.ravel(mf.mo_occ) | ||
spin = np.zeros(mo_energy.size, dtype=int) | ||
spin[mf.mo_energy[0].size:] = 1 | ||
mo_type = 'UHF' | ||
else: | ||
mo_energy = mf.mo_energy | ||
mo = mf.mo_coeff | ||
mo_occ = mf.mo_occ | ||
spin = np.zeros(mo_energy.size, dtype=int) | ||
mo_type = 'RHF' | ||
trexio.write_mo_type(trexio_file, mo_type) | ||
idx = _order_ao_index(mf.mol) | ||
trexio.write_ao_num(trexio_file, int(mol.nao)) | ||
trexio.write_mo_num(trexio_file, mo_energy.size) | ||
trexio.write_mo_coefficient(trexio_file, mo[idx].T.ravel()) | ||
trexio.write_mo_energy(trexio_file, mo_energy) | ||
trexio.write_mo_occupation(trexio_file, mo_occ) | ||
trexio.write_mo_spin(trexio_file, spin) | ||
|
||
def _cc_to_trexio(cc_obj, trexio_file): | ||
raise NotImplementedError | ||
|
||
def _mcscf_to_trexio(cas_obj, trexio_file): | ||
raise NotImplementedError | ||
|
||
def mol_from_trexio(filename, backend='h5'): | ||
mol = gto.Mole() | ||
with trexio.File(filename, 'r', back_end=_mode(backend)) as tf: | ||
assert trexio.read_basis_type(tf) == 'Gaussian' | ||
if trexio.has_ecp(tf): | ||
raise NotImplementedError | ||
labels = trexio.read_nucleus_label(tf) | ||
coords = trexio.read_nucleus_coord(tf) | ||
elements = [s+str(i) for i, s in enumerate(labels)] | ||
mol.atom = list(zip(elements, coords)) | ||
mol.unit = 'Bohr' | ||
if trexio.has_ao_cartesian(tf): | ||
mol.cart = trexio.read_ao_cartesian(tf) == 1 | ||
|
||
if trexio.has_nucleus_point_group(tf): | ||
mol.symmetry = trexio.read_nucleus_point_group(tf) | ||
|
||
nuc_idx = trexio.read_basis_nucleus_index(tf) | ||
ls = trexio.read_basis_shell_ang_mom(tf) | ||
prim2sh = trexio.read_basis_shell_index(tf) | ||
exps = trexio.read_basis_exponent(tf) | ||
coef = trexio.read_basis_coefficient(tf) | ||
|
||
basis = {} | ||
exps = _group_by(exps, prim2sh) | ||
coef = _group_by(coef, prim2sh) | ||
p1 = 0 | ||
for ia, at_ls in enumerate(_group_by(ls, nuc_idx)): | ||
p0, p1 = p1, p1 + at_ls.size | ||
at_basis = [[l, *zip(e, c)] | ||
for l, e, c in zip(ls[p0:p1], exps[p0:p1], coef[p0:p1])] | ||
basis[elements[ia]] = at_basis | ||
|
||
# To avoid the mol.build() sort the basis, disable mol.basis and set the | ||
# internal data _basis directly. | ||
mol.basis = {} | ||
mol._basis = basis | ||
return mol.build() | ||
|
||
def scf_from_trexio(filename, backend='h5'): | ||
mol = mol_from_trexio(filename, backend) | ||
with trexio.File(filename, 'r', back_end=_mode(backend)) as tf: | ||
mo_energy = trexio.read_mo_energy(tf) | ||
mo = trexio.read_mo_coefficient(tf) | ||
mo_occ = trexio.read_mo_occupation(tf) | ||
|
||
nao = mol.nao | ||
assert mo.shape == (nao, nao) # RHF only | ||
mf = mol.RHF() | ||
mf.mo_coeff = np.empty_like(mo).T | ||
idx = _order_ao_index(mol) | ||
mf.mo_coeff[idx] = mo.T | ||
mf.mo_energy = mo_energy | ||
mf.mo_occ = mo_occ | ||
return mf | ||
|
||
def write_eri(eri, filename, backend='h5'): | ||
num_integrals = eri.size | ||
if eri.ndim == 4: | ||
n = eri.shape[0] | ||
idx = lib.cartesian_prod([np.arange(n, dtype=np.int32)]*4) | ||
elif eri.ndim == 2: # 4-fold symmetry | ||
npair = eri.shape[0] | ||
n = int((npair * 2)**.5) | ||
idx_pair = np.argwhere(np.arange(n)[:,None] >= np.arange(n)) | ||
idx = np.empty((npair, npair, 4), dtype=np.int32) | ||
idx[:,:,:2] = idx_pair[:,None,:] | ||
idx[:,:,2:] = idx_pair[None,:,:] | ||
idx = idx.reshape(npair**2, 4) | ||
elif eri.ndim == 1: # 8-fold symmetry | ||
npair = int((eri.shape[0] * 2)**.5) | ||
n = int((npair * 2)**.5) | ||
idx_pair = np.argwhere(np.arange(n)[:,None] >= np.arange(n)) | ||
idx = np.empty((npair, npair, 4), dtype=np.int32) | ||
idx[:,:,:2] = idx_pair[:,None,:] | ||
idx[:,:,2:] = idx_pair[None,:,:] | ||
idx = idx[np.tril_indices(npair)] | ||
|
||
with trexio.File(filename, 'w', back_end=_mode(backend)) as tf: | ||
trexio.write_mo_2e_int_eri(tf, 0, num_integrals, idx, eri.ravel()) | ||
|
||
def read_eri(filename, backend='h5'): | ||
'''Read ERIs in AO basis, 8-fold symmetry is assumed''' | ||
with trexio.File(filename, 'r', back_end=_mode(backend)) as tf: | ||
nmo = trexio.read_mo_num(tf) | ||
nao_pair = nmo * (nmo+1) // 2 | ||
eri_size = nao_pair * (nao_pair+1) // 2 | ||
idx, data, n_read, eof_flag = trexio.read_mo_2e_int_eri(tf, 0, eri_size) | ||
eri = np.zeros(eri_size) | ||
x = idx[:,0]*(idx[:,0]+1)//2 + idx[:,1] | ||
y = idx[:,2]*(idx[:,2]+1)//2 + idx[:,3] | ||
eri[x*(x+1)//2+y] = data | ||
return eri | ||
|
||
def _order_ao_index(mol): | ||
if mol.cart: | ||
return np.arange(mol.nao) | ||
|
||
# reorder spherical functions to | ||
# Pz, Px, Py | ||
# D 0, D+1, D-1, D+2, D-2 | ||
# F 0, F+1, F-1, F+2, F-2, F+3, F-3 | ||
# G 0, G+1, G-1, G+2, G-2, G+3, G-3, G+4, G-4 | ||
# ... | ||
lmax = mol._bas[:,gto.ANG_OF].max() | ||
cache_by_l = [np.array([0]), np.array([2, 0, 1])] | ||
for l in range(2, lmax + 1): | ||
idx = np.empty(l*2+1, dtype=int) | ||
idx[ ::2] = l - np.arange(0, l+1) | ||
idx[1::2] = np.arange(l+1, l+l+1) | ||
cache_by_l.append(idx) | ||
|
||
idx = [] | ||
off = 0 | ||
for ib in range(mol.nbas): | ||
l = mol.bas_angular(ib) | ||
for n in range(mol.bas_nctr(ib)): | ||
idx.append(cache_by_l[l] + off) | ||
off += l * 2 + 1 | ||
return np.hstack(idx) | ||
|
||
def _mode(code): | ||
if code == 'h5': | ||
return 0 | ||
else: # trexio text | ||
return 1 | ||
|
||
def _group_by(a, keys): | ||
'''keys must be sorted''' | ||
assert all(keys[:-1] <= keys[1:]) | ||
idx = np.unique(keys, return_index=True)[1] | ||
return np.split(a, idx[1:]) |