diff --git a/automol/vmat.py b/automol/vmat.py index 3b29a74b..f83fb876 100644 --- a/automol/vmat.py +++ b/automol/vmat.py @@ -12,7 +12,7 @@ from .util import ZmatConv, dict_, zmat_conv Symbol = str -Key = str | None +Key = int | None Name = str | None KeyRow = tuple[Key, Key, Key] NameRow = tuple[Name, Name, Name] @@ -218,7 +218,7 @@ def coordinates( return coo_dct -def distance_coordinates(vma: VMatrix) -> dict[str, tuple[int, ...]]: +def distance_coordinates(vma: VMatrix) -> dict[Name, CoordinateKey]: """Get the distance coordinates by coordinate name. :param vma: V-Matrix @@ -227,7 +227,7 @@ def distance_coordinates(vma: VMatrix) -> dict[str, tuple[int, ...]]: return dict_.by_key(coordinates(vma, multi=False), distance_names(vma)) -def central_angle_coordinates(vma: VMatrix) -> dict[str, tuple[int, ...]]: +def central_angle_coordinates(vma: VMatrix) -> dict[Name, CoordinateKey]: """Get the central angle coordinates by coordinate name. :param vma: V-Matrix @@ -236,7 +236,7 @@ def central_angle_coordinates(vma: VMatrix) -> dict[str, tuple[int, ...]]: return dict_.by_key(coordinates(vma, multi=False), central_angle_names(vma)) -def dihedral_angle_coordinates(vma: VMatrix) -> dict[str, tuple[int, ...]]: +def dihedral_angle_coordinates(vma: VMatrix) -> dict[Name, CoordinateKey]: """Get the dihedral angle coordinates by coordinate name. :param vma: V-Matrix @@ -293,7 +293,7 @@ def distance_names(vma: VMatrix) -> tuple[str]: return tuple(more_itertools.unique_everseen(name_mat[1:, 0])) -def central_angle_names(vma: VMatrix) -> tuple[str]: +def central_angle_names(vma: VMatrix) -> tuple[Name, ...]: """Obtain names of all central-angle coordinates defined in the V-Matrix. :param vma: V-Matrix @@ -304,7 +304,7 @@ def central_angle_names(vma: VMatrix) -> tuple[str]: return tuple(more_itertools.unique_everseen(name_mat[2:, 1])) -def dihedral_angle_names(vma: VMatrix) -> tuple[str]: +def dihedral_angle_names(vma: VMatrix) -> tuple[Name, ...]: """Obtain names of all dihedral angle coordinates defined in the V-Matrix. :param vma: V-Matrix @@ -315,7 +315,7 @@ def dihedral_angle_names(vma: VMatrix) -> tuple[str]: return tuple(more_itertools.unique_everseen(name_mat[3:, 2])) -def angle_names(vma: VMatrix) -> tuple[str]: +def angle_names(vma: VMatrix) -> tuple[Name, ...]: """Obtain names of all angle coordinates defined in the V-Matrix. :param vma: V-Matrix @@ -358,7 +358,7 @@ def standard_names(vma: VMatrix, shift: int = 0) -> dict[Name, Name]: return name_dct -def standard_name_matrix(vma: VMatrix, shift: int = 0) -> tuple[tuple[str]]: +def standard_name_matrix(vma: VMatrix, shift: int = 0) -> NameMatrix: """Build a name matrix of the V-Matrix where all of the coordinate names have been standardized: RN: (1<=N<=Ncoords) @@ -366,10 +366,8 @@ def standard_name_matrix(vma: VMatrix, shift: int = 0) -> tuple[tuple[str]]: DN: (1<=N<=Ncoords). :param vma: V-Matrix - :type vma: automol V-Matrix data structure :param shift: value to shift the keys by when obtaining the keys - :type shift: int - :rtype: tuple(tuple(str)) + :return: Name matrix """ natms = count(vma) @@ -383,7 +381,7 @@ def standard_name_matrix(vma: VMatrix, shift: int = 0) -> tuple[tuple[str]]: return name_mat -def distance_coordinate_name(zma: VMatrix, key1: int, key2: int) -> str: +def distance_coordinate_name(zma: VMatrix, key1: int, key2: int) -> Name: """Get the name of a distance coordinate for a given bond. :param zma: the z-matrix @@ -402,7 +400,9 @@ def distance_coordinate_name(zma: VMatrix, key1: int, key2: int) -> str: return name -def central_angle_coordinate_name(zma: VMatrix, key1: int, key2: int, key3: int) -> str: +def central_angle_coordinate_name( + zma: VMatrix, key1: int, key2: int, key3: int +) -> Name: """Get the name of angle coordinate for a set of 3 atoms. :param zma: The z-matrix @@ -424,7 +424,7 @@ def central_angle_coordinate_name(zma: VMatrix, key1: int, key2: int, key3: int) def dihedral_angle_coordinate_name( zma: VMatrix, key1: int, key2: int, key3: int, key4: int -): +) -> Name: """Get the name of dihedral coordinate for a set of 4 atoms. :param zma: The z-matrix @@ -451,7 +451,7 @@ def dihedral_angle_coordinate_name( # # dummy atom functions -def dummy_keys(zma: VMatrix) -> tuple[int]: +def dummy_keys(zma: VMatrix) -> tuple[Key, ...]: """Obtain keys to dummy atoms in the Z-Matrix. :param zma: Z-Matrix @@ -461,7 +461,7 @@ def dummy_keys(zma: VMatrix) -> tuple[int]: return keys -def dummy_coordinate_names(vma: VMatrix) -> tuple[str]: +def dummy_coordinate_names(vma: VMatrix) -> tuple[Name, ...]: """Obtain names of all coordinates associated with dummy atoms defined in the V-Matrix. @@ -484,16 +484,15 @@ def dummy_coordinate_names(vma: VMatrix) -> tuple[str]: return dummy_names -def dummy_source_dict(zma, dir_: bool = True): +def dummy_source_dict( + zma: VMatrix, dir_: bool = True +) -> dict[int, int] | tuple[int, int]: """Obtain keys to dummy atoms in the Z-Matrix, along with their parent atoms. :param zma: Z-Matrix - :type zma: automol Z-Matrix data structure :param dir_: Include linear direction atoms? defaults to True - :type dir_: bool, optional :returns: A dictionary mapping dummy atoms onto their parent atoms - :rtype: dict[int: int] """ key_mat = key_matrix(zma) dum_keys = dummy_keys(zma) @@ -513,14 +512,12 @@ def dummy_source_dict(zma, dir_: bool = True): return src_dct -def conversion_info(zma) -> ZmatConv: +def conversion_info(zma: VMatrix) -> ZmatConv: """Get the conversion information for this z-matrix, relative to geometry following the same atom order. :param zma: Z-Matrix - :type zma: automol Z-Matrix data structure :return: The z-matrix conversion - :rtype: ZmatConv """ zcount = count(zma) src_zkeys_dct = dummy_source_dict(zma) @@ -529,14 +526,12 @@ def conversion_info(zma) -> ZmatConv: # # V-Matrix-specific functions # # # setters -def set_key_matrix(vma, key_mat): +def set_key_matrix(vma: VMatrix, key_mat: KeyMatrix) -> VMatrix: """Re-set the key matrix of a V-Matrix using the input key matrix. :param vma: V-Matrix - :type vma: automol V-Matrix data structure - :param key_mat: key matrix of V-Matrix coordinate keys - :type key_mat: tuple(tuple(int)) - :rtype: tuple(str) + :param key_mat: Key matrix of V-Matrix coordinate keys + :return: VMatrix with input keys """ symbs = symbols(vma) name_mat = name_matrix(vma) @@ -545,14 +540,12 @@ def set_key_matrix(vma, key_mat): return vma -def set_name_matrix(vma, name_mat): +def set_name_matrix(vma: VMatrix, name_mat: NameMatrix) -> VMatrix: """Re-set the name matrix of a V-Matrix using the input name matrix. :param vma: V-Matrix - :type vma: automol V-Matrix data structure :param name_mat: name matrix of V-Matrix coordinate names - :type name_mat: tuple(tuple(int)) - :rtype: tuple(str) + :return: V-Matrix reset """ symbs = symbols(vma) key_mat = key_matrix(vma) @@ -562,14 +555,12 @@ def set_name_matrix(vma, name_mat): # # # names and naming -def rename(vma, name_dct): +def rename(vma: VMatrix, name_dct: dict[str, str]) -> VMatrix: """Rename a subset of the coordinates of a V-Matrix. :param vma: V-Matrix - :type vma: automol V-Matrix data structure - :param name_dct: mapping from old coordinate names to new ones - :type name_dct: dict[str: str] - :rtype: automol V-Matrix data strucutre + :param name_dct: Mapping from old coordinate names to new ones + :return: VMatrix with new names """ orig_name_mat = numpy.array(name_matrix(vma)) tril_idxs = numpy.tril_indices(orig_name_mat.shape[0], -1, m=3) @@ -586,7 +577,7 @@ def rename(vma, name_dct): return from_data(symbols(vma), key_matrix(vma), name_mat) -def standard_form(vma, shift=0): +def standard_form(vma: VMatrix, shift: int = 0) -> VMatrix: """Build a V-Matrix where all of the coordinate names of an input V-Matrix have been put into standard form: RN: (1<=N<=Ncoords) @@ -594,30 +585,29 @@ def standard_form(vma, shift=0): DN: (1<=N<=Ncoords). :param vma: V-Matrix - :type vma: automol V-Matrix data structure :param shift: value to shift the keys by when obtaining the keys - :type shift: int - :rtype: automol V-Matrix data strucutre + :return: V-Matrix of coordinate names """ name_mat = standard_name_matrix(vma, shift=shift) return set_name_matrix(vma, name_mat) # # # add/remove atoms -def add_atom(vma, symb, key_row, name_row=None, one_indexed=False): +def add_atom( + vma: VMatrix, + symb: str, + key_row: KeyRow, + name_row: NameRow = None, + one_indexed: bool = False, +) -> VMatrix: """Add an atom to a V-Matrix. :param vma: V-Matrix - :type vma: automol V-Matrix data structure - :param symb: symbol of atom to add - :type symb: str - :param key_row: row of keys to define new atom added to key matrix - :type key_row: tuple(int) - :param name_row: row of names to define new atom added to name matrix - :type name_row: tuple(str) - :param one_indexed: parameter to store keys in one-indexing - :type one_indexed: bool - :rtype: automol V-Matrix data structure + :param symb: Symbol of atom to add + :param key_row: Row of keys to define new atom added to key matrix + :param name_row: Row of names to define new atom added to name matrix + :param one_indexed: Parameter to store keys in one-indexing + :return: V-Matrix with new atom """ symbs = symbols(vma) symbs += (symb,) @@ -632,15 +622,13 @@ def add_atom(vma, symb, key_row, name_row=None, one_indexed=False): return vma -def remove_atom(vma, key): +def remove_atom(vma: VMatrix, key: Key) -> VMatrix: """Remove an atom from a V-Matrix. Error raised if attempting to remove atom other atoms depend on. :param vma: V-Matrix - :type vma: automol V-Matrix data structure :param key: key of atom to remove - :type key: str - :rtype: automol V-Matrix data structure + :return: V-Matrix without atom """ symbs = list(symbols(vma)) symbs.pop(key) @@ -664,12 +652,11 @@ def remove_atom(vma, key): # # # validation -def is_valid(vma): +def is_valid(vma: VMatrix) -> bool: """Assess if a V-Matrix has proper structure. :param vma: V-Matrix - :type vma: automol V-Matrix data structure - :rtype: bool + :return: True, if V-Matrix has proper structure; False otherwise """ ret = True try: @@ -682,28 +669,25 @@ def is_valid(vma): return ret -def is_standard_form(vma): +def is_standard_form(vma: VMatrix) -> bool: """Assesses if the names of the V-Matrix are in standard form: RN: (1<=N<=Ncoords) AN: (2<=N<=Ncoords) DN: (1<=N<=Ncoords). :param vma: V-Matrix - :type vma: automol V-Matrix data structure - :rtype: bool + :return: True if V-Matrix is in standard form, False if not """ return names(vma) == names(standard_form(vma)) # # # I/O -def string(vma, one_indexed=False): +def string(vma: VMatrix, one_indexed: bool = False) -> str: """Write a V-Matrix object to a string. :param vma: V-Matrix - :type vma: automol V-Matrix data structure - :param one_indexed: parameter to write keys in one-indexing - :type one_indexed: bool - :rtype: str + :param one_indexed: Parameter to write keys in one-indexing + :return: String """ shift = 1 if one_indexed else 0 symbs = symbols(vma) @@ -728,14 +712,12 @@ def _line_string(row_idx): return vma_str -def from_string(vma_str, one_indexed=None): +def from_string(vma_str: str, one_indexed: bool | None = None) -> VMatrix: """Parse a V-Matrix object from a string. :param vma_str: string containing a V-Matrix - :type vma_str: str :param one_indexed: Read a one-indexed string? - :type one_indexed: bool - :rtype: automol V-Matrix data structure + :return: V-Matrix of string """ rows = VMAT_LINES.parseString(vma_str).asList() symbs = [r.pop(0) for r in rows] @@ -748,17 +730,16 @@ def from_string(vma_str, one_indexed=None): # # helpers -def _key_matrix(key_mat, natms, one_indexed=None): +def _key_matrix( + key_mat: KeyMatrix, natms: int, one_indexed: bool | None = None +) -> KeyMatrix: """Build name matrix of the V-Matrix that contains the coordinate keys by row and column. :param key_mat: key matrix of V-Matrix coordinate keys - :type key_mat: tuple(tuple(int)) :param natms: number of atoms - :type natms: int :param one_indexed: parameter to write keys in one-indexing - :type one_indexed: bool - :rtype: tuple(tuple(str)) + :return: Name matrix of V-Matrix """ if natms == 1: return ((None, None, None),) @@ -780,15 +761,13 @@ def _key_matrix(key_mat, natms, one_indexed=None): return tuple(map(tuple, key_mat)) -def _name_matrix(name_mat, natms): +def _name_matrix(name_mat: NameMatrix, natms: int) -> NameMatrix: """Build name matrix of the V-Matrix that contains the coordinate names by row and column. :param name_mat: key matrix of V-Matrix coordinate keys - :type name_mat: tuple(tuple(int)) :param natms: number of atoms - :type natms: int - :rtype: tuple(tuple(str)) + :return: Name matrix of Vmatrix """ if name_mat is None: name_mat = numpy.empty((natms, 3), dtype=object) @@ -815,12 +794,11 @@ def _name_matrix(name_mat, natms): return tuple(map(tuple, name_mat)) -def _is_sequence_of_triples(obj): +def _is_sequence_of_triples(obj) -> bool: """Assess if input object sequence has length of three. - :param obj: object with __len__ attribute - :type obj: list, tuple, dict - :rtype: bool + :param obj: Object with __len__ attribute + :return: True if object is less than len=3, False if not """ ret = hasattr(obj, "__len__") if ret: