Skip to content

Commit

Permalink
finishing touches
Browse files Browse the repository at this point in the history
adding more ndarralike class
  • Loading branch information
Rosalbam1 committed Aug 26, 2024
1 parent 0896e29 commit c4229fa
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions automol/embed/_cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
X = numpy.newaxis


def volume(xmat, idxs):
def volume(xmat:NDArrayLike2D, idxs:list):
"""Calculate signed tetrahedral volume for a tetrad of atoms.
for a tetrad of four atoms (1, 2, 3, 4) around a central atom, the signed
Expand All @@ -71,7 +71,7 @@ def volume(xmat, idxs):
return vol


def volume_gradient(xmat, idxs):
def volume_gradient(xmat:NDArrayLike2D, idxs:list):
"""Calculate the tetrahedral volume gradient for a tetrad of atoms."""
xmat = numpy.array(xmat)
idxs = list(idxs)
Expand Down Expand Up @@ -122,7 +122,7 @@ def error_function_(
pla_dct = {} if pla_dct is None else pla_dct
chip_dct = {**chi_dct, **pla_dct}

def _function(xmat):
def _function(xmat:NDArrayLike2D):
dmat = distance_matrix_from_coordinates(xmat)

# distance error (equation 61 in the paper referenced above)
Expand Down Expand Up @@ -259,8 +259,8 @@ def _gradient(xmat):
def error_function_numerical_gradient_(
lmat: NDArrayLike2D,
umat: NDArrayLike2D,
chi_dct=None,
pla_dct=None,
chi_dct:SignedVolumeContraints=None,
pla_dct:SignedVolumeContraints=None,
wdist=1.0,
wchip=1.0,
wdim4=1.0,
Expand Down Expand Up @@ -316,11 +316,11 @@ def _function_of_alpha(alpha):


def cleaned_up_coordinates(
xmat,
lmat,
umat,
chi_dct=None,
pla_dct=None,
xmat:NDArrayLike2D,
lmat:NDArrayLike2D,
umat:NDArrayLike2D,
chi_dct:SignedVolumeContraints=None,
pla_dct:SignedVolumeContraints=None,
conv_=None,
max_dist_err=0.2,
grad_thresh=0.2,
Expand Down Expand Up @@ -401,7 +401,7 @@ def _is_converged(xmat, err, grad):
return _is_converged


def distance_convergence_checker_(lmat, umat, max_dist_err=0.2):
def distance_convergence_checker_(lmat:NDArrayLike2D, umat:NDArrayLike2D, max_dist_err=0.2):
"""Convergence checker based on the maximum distance error."""

def _is_converged(xmat, err, grad):
Expand Down Expand Up @@ -448,7 +448,7 @@ def _is_converged(xmat, err, grad):
return _is_converged


def minimize_error(xmat, err_, grad_, conv_, maxiter=None):
def minimize_error(xmat:NDArrayLike2D, err_, grad_, conv_, maxiter=None):
"""Do conjugate-gradients error minimization.
:param err_: a callable error function of xmat
Expand Down

0 comments on commit c4229fa

Please sign in to comment.